// Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-04-25 03:01:17 -04:00).
// 704 lines, 22 KiB, Rust.
use super::errors::GraphError;
|
|
use super::quantize_float;
|
|
use crate::circuit::InputType;
|
|
use crate::fieldutils::integer_rep_to_felt;
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
use crate::graph::postgres::Client;
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
use crate::tensor::Tensor;
|
|
use crate::EZKL_BUF_CAPACITY;
|
|
use halo2curves::bn256::Fr as Fp;
|
|
#[cfg(feature = "python-bindings")]
|
|
use pyo3::prelude::*;
|
|
#[cfg(feature = "python-bindings")]
|
|
use pyo3::types::PyDict;
|
|
#[cfg(feature = "python-bindings")]
|
|
use pyo3::ToPyObject;
|
|
use serde::ser::SerializeStruct;
|
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
|
use std::io::BufReader;
|
|
use std::io::BufWriter;
|
|
use std::io::Read;
|
|
use std::panic::UnwindSafe;
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
use tract_onnx::tract_core::{
|
|
tract_data::{prelude::Tensor as TractTensor, TVec},
|
|
value::TValue,
|
|
};
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
|
|
|
|
/// Number of decimals paired with an on-chain call (second tuple element of `CallsToAccount::call_data`).
type Decimals = u8;
|
|
/// ABI-encoded function call stored as a string (first tuple element of `CallsToAccount::call_data`).
type Call = String;
|
|
/// String holding an endpoint address; used for the on-chain RPC url and for the postgres host.
type RPCUrl = String;
|
|
|
|
///
|
|
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
|
pub enum FileSourceInner {
|
|
/// Inner elements of float inputs coming from a file
|
|
Float(f64),
|
|
/// Inner elements of bool inputs coming from a file
|
|
Bool(bool),
|
|
/// Inner elements of inputs coming from a witness
|
|
Field(Fp),
|
|
}
|
|
|
|
impl FileSourceInner {
|
|
///
|
|
pub fn is_float(&self) -> bool {
|
|
matches!(self, FileSourceInner::Float(_))
|
|
}
|
|
///
|
|
pub fn is_bool(&self) -> bool {
|
|
matches!(self, FileSourceInner::Bool(_))
|
|
}
|
|
///
|
|
pub fn is_field(&self) -> bool {
|
|
matches!(self, FileSourceInner::Field(_))
|
|
}
|
|
}
|
|
|
|
impl Serialize for FileSourceInner {
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: Serializer,
|
|
{
|
|
match self {
|
|
FileSourceInner::Field(data) => data.serialize(serializer),
|
|
FileSourceInner::Bool(data) => data.serialize(serializer),
|
|
FileSourceInner::Float(data) => data.serialize(serializer),
|
|
}
|
|
}
|
|
}
|
|
|
|
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
|
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
|
impl<'de> Deserialize<'de> for FileSourceInner {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
|
|
|
let bool_try: Result<bool, _> = serde_json::from_str(this_json.get());
|
|
if let Ok(t) = bool_try {
|
|
return Ok(FileSourceInner::Bool(t));
|
|
}
|
|
let float_try: Result<f64, _> = serde_json::from_str(this_json.get());
|
|
if let Ok(t) = float_try {
|
|
return Ok(FileSourceInner::Float(t));
|
|
}
|
|
let field_try: Result<Fp, _> = serde_json::from_str(this_json.get());
|
|
if let Ok(t) = field_try {
|
|
return Ok(FileSourceInner::Field(t));
|
|
}
|
|
|
|
Err(serde::de::Error::custom(
|
|
"failed to deserialize FileSourceInner",
|
|
))
|
|
}
|
|
}
|
|
|
|
/// Elements of inputs coming from a file: one inner vector per input, each
/// holding that input's elements as a flat list.
|
|
pub type FileSource = Vec<Vec<FileSourceInner>>;
|
|
|
|
impl FileSourceInner {
|
|
/// Create a new FileSourceInner
|
|
pub fn new_float(f: f64) -> Self {
|
|
FileSourceInner::Float(f)
|
|
}
|
|
/// Create a new FileSourceInner
|
|
pub fn new_field(f: Fp) -> Self {
|
|
FileSourceInner::Field(f)
|
|
}
|
|
/// Create a new FileSourceInner
|
|
pub fn new_bool(f: bool) -> Self {
|
|
FileSourceInner::Bool(f)
|
|
}
|
|
|
|
///
|
|
pub fn as_type(&mut self, input_type: &InputType) {
|
|
match self {
|
|
FileSourceInner::Float(f) => input_type.roundtrip(f),
|
|
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
|
|
FileSourceInner::Field(_) => {}
|
|
}
|
|
}
|
|
|
|
/// Convert to a field element
|
|
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
|
match self {
|
|
FileSourceInner::Float(f) => integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap()),
|
|
FileSourceInner::Bool(f) => {
|
|
if *f {
|
|
Fp::one()
|
|
} else {
|
|
Fp::zero()
|
|
}
|
|
}
|
|
FileSourceInner::Field(f) => *f,
|
|
}
|
|
}
|
|
/// Convert to a float
|
|
pub fn to_float(&self) -> f64 {
|
|
match self {
|
|
FileSourceInner::Float(f) => *f,
|
|
FileSourceInner::Bool(f) => {
|
|
if *f {
|
|
1.0
|
|
} else {
|
|
0.0
|
|
}
|
|
}
|
|
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Inner elements of inputs/outputs coming from on-chain
|
|
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
|
pub struct OnChainSource {
|
|
/// Vector of calls to accounts
|
|
pub calls: Vec<CallsToAccount>,
|
|
/// RPC url
|
|
pub rpc: RPCUrl,
|
|
}
|
|
|
|
impl OnChainSource {
|
|
/// Create a new OnChainSource
|
|
pub fn new(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
|
OnChainSource { calls, rpc }
|
|
}
|
|
}
|
|
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
/// Inner elements of inputs/outputs coming from postgres DB
|
|
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
|
pub struct PostgresSource {
|
|
/// postgres host
|
|
pub host: RPCUrl,
|
|
/// user to connect to postgres
|
|
pub user: String,
|
|
/// password to connect to postgres
|
|
pub password: String,
|
|
/// query to execute
|
|
pub query: String,
|
|
/// dbname
|
|
pub dbname: String,
|
|
/// port
|
|
pub port: String,
|
|
}
|
|
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
impl PostgresSource {
|
|
/// Create a new PostgresSource
|
|
pub fn new(
|
|
host: RPCUrl,
|
|
port: String,
|
|
user: String,
|
|
query: String,
|
|
dbname: String,
|
|
password: String,
|
|
) -> Self {
|
|
PostgresSource {
|
|
host,
|
|
user,
|
|
password,
|
|
query,
|
|
dbname,
|
|
port,
|
|
}
|
|
}
|
|
|
|
/// Fetch data from postgres
|
|
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
|
// clone to move into thread
|
|
let user = self.user.clone();
|
|
let host = self.host.clone();
|
|
let query = self.query.clone();
|
|
let dbname = self.dbname.clone();
|
|
let port = self.port.clone();
|
|
let password = self.password.clone();
|
|
|
|
let config = if password.is_empty() {
|
|
format!(
|
|
"host={} user={} dbname={} port={}",
|
|
host, user, dbname, port
|
|
)
|
|
} else {
|
|
format!(
|
|
"host={} user={} dbname={} port={} password={}",
|
|
host, user, dbname, port, password
|
|
)
|
|
};
|
|
|
|
let mut client = Client::connect(&config).await?;
|
|
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
|
// extract rows from query
|
|
for row in client.query(&query, &[]).await? {
|
|
// extract features from row
|
|
for i in 0..row.len() {
|
|
res.push(row.get(i));
|
|
}
|
|
}
|
|
Ok(vec![res])
|
|
}
|
|
|
|
/// Fetch data from postgres and format it as a FileSource
|
|
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
|
Ok(self
|
|
.fetch()
|
|
.await?
|
|
.iter()
|
|
.map(|d| {
|
|
d.iter()
|
|
.map(|d| {
|
|
FileSourceInner::Float(
|
|
d.n.as_ref()
|
|
.unwrap()
|
|
.to_f64()
|
|
.ok_or("could not convert decimal to f64")
|
|
.unwrap(),
|
|
)
|
|
})
|
|
.collect()
|
|
})
|
|
.collect())
|
|
}
|
|
}
|
|
|
|
impl OnChainSource {
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
/// Create dummy local on-chain data to test the OnChain data source
|
|
pub async fn test_from_file_data(
|
|
data: &FileSource,
|
|
scales: Vec<crate::Scale>,
|
|
mut shapes: Vec<Vec<usize>>,
|
|
rpc: Option<&str>,
|
|
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
|
|
use crate::eth::{
|
|
evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
|
|
};
|
|
use log::debug;
|
|
|
|
// Set up local anvil instance for reading on-chain data
|
|
let (client, client_address) = crate::eth::setup_eth_backend(rpc, None).await?;
|
|
|
|
let mut scales = scales;
|
|
// set scales to 1 where data is a field element
|
|
for (idx, i) in data.iter().enumerate() {
|
|
if i.iter().all(|e| e.is_field()) {
|
|
scales[idx] = 0;
|
|
shapes[idx] = vec![i.len()];
|
|
}
|
|
}
|
|
|
|
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
|
|
debug!("Calls to accounts: {:?}", calls_to_accounts);
|
|
let inputs =
|
|
read_on_chain_inputs(client.clone(), client_address, &calls_to_accounts).await?;
|
|
debug!("Inputs: {:?}", inputs);
|
|
|
|
let mut quantized_evm_inputs = vec![];
|
|
|
|
let mut prev = 0;
|
|
for (idx, i) in data.iter().enumerate() {
|
|
quantized_evm_inputs.extend(
|
|
evm_quantize(
|
|
client.clone(),
|
|
vec![scales[idx]; i.len()],
|
|
&(
|
|
inputs.0[prev..i.len()].to_vec(),
|
|
inputs.1[prev..i.len()].to_vec(),
|
|
),
|
|
)
|
|
.await?,
|
|
);
|
|
prev += i.len();
|
|
}
|
|
|
|
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
|
|
let mut inputs: Vec<Tensor<Fp>> = vec![];
|
|
for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) {
|
|
let mut t: Tensor<Fp> = input.iter().cloned().collect();
|
|
t.reshape(&shape)?;
|
|
inputs.push(t);
|
|
}
|
|
|
|
let used_rpc = rpc.unwrap_or(DEFAULT_ANVIL_ENDPOINT).to_string();
|
|
|
|
// Fill the input_data field of the GraphData struct
|
|
Ok((
|
|
inputs,
|
|
OnChainSource::new(calls_to_accounts.clone(), used_rpc),
|
|
))
|
|
}
|
|
}
|
|
|
|
/// Defines the view only calls to accounts to fetch the on-chain input data.
|
|
/// This data will be included as part of the first elements in the publicInputs
|
|
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
|
|
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
|
pub struct CallsToAccount {
|
|
/// A vector of tuples, where index 0 of tuples
|
|
/// are the byte strings representing the ABI encoded function calls to
|
|
/// read the data from the address. This call must return a single
|
|
/// elementary type (<https://docs.soliditylang.org/en/v0.8.20/abi-spec.html#types>).
|
|
/// The second index of the tuple is the number of decimals for f32 conversion.
|
|
/// We don't support dynamic types currently.
|
|
pub call_data: Vec<(Call, Decimals)>,
|
|
/// Address of the contract to read the data from.
|
|
pub address: String,
|
|
}
|
|
/// Enum that defines source of the inputs/outputs to the EZKL model
|
|
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
|
|
#[serde(untagged)]
|
|
pub enum DataSource {
|
|
/// .json File data source.
|
|
File(FileSource),
|
|
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
|
|
OnChain(OnChainSource),
|
|
/// Postgres DB
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
DB(PostgresSource),
|
|
}
|
|
|
|
impl Default for DataSource {
|
|
fn default() -> Self {
|
|
DataSource::File(vec![vec![]])
|
|
}
|
|
}
|
|
|
|
impl From<FileSource> for DataSource {
    /// Wraps already-built file data without modifying it.
    fn from(data: FileSource) -> Self {
        Self::File(data)
    }
}
|
|
|
|
impl From<Vec<Vec<Fp>>> for DataSource {
|
|
fn from(data: Vec<Vec<Fp>>) -> Self {
|
|
DataSource::File(
|
|
data.iter()
|
|
.map(|e| e.iter().map(|e| FileSourceInner::Field(*e)).collect())
|
|
.collect(),
|
|
)
|
|
}
|
|
}
|
|
|
|
impl From<Vec<Vec<f64>>> for DataSource {
|
|
fn from(data: Vec<Vec<f64>>) -> Self {
|
|
DataSource::File(
|
|
data.iter()
|
|
.map(|e| e.iter().map(|e| FileSourceInner::Float(*e)).collect())
|
|
.collect(),
|
|
)
|
|
}
|
|
}
|
|
|
|
impl From<OnChainSource> for DataSource {
    /// Wraps an on-chain source without modifying it.
    fn from(data: OnChainSource) -> Self {
        Self::OnChain(data)
    }
}
|
|
|
|
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
|
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
|
impl<'de> Deserialize<'de> for DataSource {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
|
|
|
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
|
|
|
|
if let Ok(t) = first_try {
|
|
return Ok(DataSource::File(t));
|
|
}
|
|
let second_try: Result<OnChainSource, _> = serde_json::from_str(this_json.get());
|
|
if let Ok(t) = second_try {
|
|
return Ok(DataSource::OnChain(t));
|
|
}
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
{
|
|
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
|
|
if let Ok(t) = third_try {
|
|
return Ok(DataSource::DB(t));
|
|
}
|
|
}
|
|
|
|
Err(serde::de::Error::custom("failed to deserialize DataSource"))
|
|
}
|
|
}
|
|
|
|
/// Input to graph as a datasource
|
|
/// Always use JSON serialization for GraphData. Seriously.
|
|
#[derive(Clone, Debug, Deserialize, Default, PartialEq)]
|
|
pub struct GraphData {
|
|
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
|
|
pub input_data: DataSource,
|
|
/// Outputs of the model / computational graph (can be empty vectors if outputs are coming from on-chain).
|
|
pub output_data: Option<DataSource>,
|
|
}
|
|
|
|
// Marker impl: declares that `GraphData` may be carried across a `catch_unwind` boundary.
impl UnwindSafe for GraphData {}
|
|
|
|
impl GraphData {
|
|
// not wasm
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
/// Convert the input data to tract data
|
|
pub fn to_tract_data(
|
|
&self,
|
|
shapes: &[Vec<usize>],
|
|
datum_types: &[tract_onnx::prelude::DatumType],
|
|
) -> Result<TVec<TValue>, GraphError> {
|
|
let mut inputs = TVec::new();
|
|
match &self.input_data {
|
|
DataSource::File(data) => {
|
|
for (i, input) in data.iter().enumerate() {
|
|
if !input.is_empty() {
|
|
let dt = datum_types[i];
|
|
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
|
|
let tt = TractTensor::from_shape(&shapes[i], &input)?;
|
|
let tt = tt.cast_to_dt(dt)?;
|
|
inputs.push(tt.into_owned().into());
|
|
}
|
|
}
|
|
}
|
|
_ => {
|
|
return Err(GraphError::InvalidDims(
|
|
0,
|
|
"non file data cannot be split into batches".to_string(),
|
|
))
|
|
}
|
|
}
|
|
Ok(inputs)
|
|
}
|
|
|
|
///
|
|
pub fn new(input_data: DataSource) -> Self {
|
|
GraphData {
|
|
input_data,
|
|
output_data: None,
|
|
}
|
|
}
|
|
|
|
/// Load the model input from a file
|
|
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
|
|
let reader = std::fs::File::open(&path).map_err(|e| {
|
|
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
|
|
})?;
|
|
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
|
|
let mut buf = String::new();
|
|
reader.read_to_string(&mut buf).map_err(|e| {
|
|
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
|
|
})?;
|
|
let graph_input = serde_json::from_str(&buf)?;
|
|
Ok(graph_input)
|
|
}
|
|
|
|
/// Save the model input to a file
|
|
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
|
|
let file = std::fs::File::create(path.clone()).map_err(|e| {
|
|
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
|
|
})?;
|
|
// buf writer
|
|
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
|
|
serde_json::to_writer(writer, self)?;
|
|
Ok(())
|
|
}
|
|
|
|
///
|
|
pub async fn split_into_batches(
|
|
&self,
|
|
input_shapes: Vec<Vec<usize>>,
|
|
) -> Result<Vec<Self>, GraphError> {
|
|
// split input data into batches
|
|
let mut batched_inputs = vec![];
|
|
|
|
let iterable = match self {
|
|
GraphData {
|
|
input_data: DataSource::File(data),
|
|
output_data: _,
|
|
} => data.clone(),
|
|
GraphData {
|
|
input_data: DataSource::OnChain(_),
|
|
output_data: _,
|
|
} => {
|
|
return Err(GraphError::InvalidDims(
|
|
0,
|
|
"on-chain data cannot be split into batches".to_string(),
|
|
))
|
|
}
|
|
#[cfg(not(target_arch = "wasm32"))]
|
|
GraphData {
|
|
input_data: DataSource::DB(data),
|
|
output_data: _,
|
|
} => data.fetch_and_format_as_file().await?,
|
|
};
|
|
|
|
for (i, shape) in input_shapes.iter().enumerate() {
|
|
// ensure the input is evenly divisible by batch_size
|
|
let input_size = shape.clone().iter().product::<usize>();
|
|
let input = &iterable[i];
|
|
if input.len() % input_size != 0 {
|
|
return Err(GraphError::InvalidDims(
|
|
0,
|
|
"calibration data length must be evenly divisible by the original input_size"
|
|
.to_string(),
|
|
));
|
|
}
|
|
let mut batches = vec![];
|
|
for batch in input.chunks(input_size) {
|
|
batches.push(batch.to_vec());
|
|
}
|
|
batched_inputs.push(batches);
|
|
}
|
|
|
|
// now merge all the batches for each input into a vector of batches
|
|
// first assert each input has the same number of batches
|
|
let num_batches = if batched_inputs.is_empty() {
|
|
0
|
|
} else {
|
|
let num_batches = batched_inputs[0].len();
|
|
for input in batched_inputs.iter() {
|
|
assert_eq!(input.len(), num_batches);
|
|
}
|
|
num_batches
|
|
};
|
|
// now merge the batches
|
|
let mut input_batches = vec![];
|
|
for i in 0..num_batches {
|
|
let mut batch = vec![];
|
|
for input in batched_inputs.iter() {
|
|
batch.push(input[i].clone());
|
|
}
|
|
input_batches.push(DataSource::File(batch));
|
|
}
|
|
|
|
if input_batches.is_empty() {
|
|
input_batches.push(DataSource::File(vec![vec![]]));
|
|
}
|
|
|
|
// create a new GraphWitness for each batch
|
|
let batches = input_batches
|
|
.into_iter()
|
|
.map(GraphData::new)
|
|
.collect::<Vec<GraphData>>();
|
|
|
|
Ok(batches)
|
|
}
|
|
}
|
|
|
|
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallsToAccount {
    /// Builds a Python dict with the keys `"account"` (the contract address)
    /// and `"call_data"`.
    ///
    /// Panics if an entry cannot be inserted into the dict.
    fn to_object(&self, py: Python) -> PyObject {
        let CallsToAccount { call_data, address } = self;

        let entries = PyDict::new(py);
        entries.set_item("account", address).unwrap();
        entries.set_item("call_data", call_data).unwrap();
        entries.to_object(py)
    }
}
|
|
|
|
#[cfg(feature = "python-bindings")]
impl ToPyObject for DataSource {
    /// Converts the data source to a Python object.
    ///
    /// File data becomes nested lists. On-chain sources become a dict with
    /// the keys `"rpc_url"` and `"calls_to_accounts"`. Postgres sources become
    /// a dict with the keys `"host"`, `"user"` and `"query"`.
    ///
    /// Panics if an entry cannot be inserted into a dict.
    fn to_object(&self, py: Python) -> PyObject {
        match self {
            DataSource::File(data) => data.to_object(py),
            DataSource::OnChain(source) => {
                let dict = PyDict::new(py);
                dict.set_item("rpc_url", &source.rpc).unwrap();
                dict.set_item("calls_to_accounts", &source.calls).unwrap();
                dict.to_object(py)
            }
            // the `DB` variant only exists outside wasm32, so the arm is gated the same way
            #[cfg(not(target_arch = "wasm32"))]
            DataSource::DB(source) => {
                let dict = PyDict::new(py);
                dict.set_item("host", &source.host).unwrap();
                dict.set_item("user", &source.user).unwrap();
                dict.set_item("query", &source.query).unwrap();
                dict.to_object(py)
            }
        }
    }
}
|
|
|
|
#[cfg(feature = "python-bindings")]
|
|
use crate::pfsys::field_to_string;
|
|
|
|
#[cfg(feature = "python-bindings")]
impl ToPyObject for FileSourceInner {
    /// Converts the element to a Python object: floats and booleans map to
    /// their Python counterparts, field elements to the string produced by
    /// `field_to_string`.
    fn to_object(&self, py: Python) -> PyObject {
        match self {
            Self::Float(value) => value.to_object(py),
            Self::Bool(flag) => flag.to_object(py),
            Self::Field(element) => {
                let text = field_to_string(element);
                text.to_object(py)
            }
        }
    }
}
|
|
|
|
impl Serialize for GraphData {
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: Serializer,
|
|
{
|
|
let mut state = serializer.serialize_struct("GraphData", 4)?;
|
|
state.serialize_field("input_data", &self.input_data)?;
|
|
state.serialize_field("output_data", &self.output_data)?;
|
|
state.end()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Backwards compatibility with the old format: float file data must
    /// serialize to, and parse back from, a bare nested JSON array.
    #[test]
    fn test_data_source_serialization_round_trip() {
        const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;

        let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);

        let serialized = serde_json::to_string(&source).unwrap();
        assert_eq!(serialized, JSON);

        let parsed: DataSource = serde_json::from_str(JSON)
            .map_err(|e| e.to_string())
            .unwrap();
        assert_eq!(parsed, source);
    }

    /// Backwards compatibility with the old format: graph data without
    /// outputs must round-trip through its JSON object representation.
    #[test]
    fn test_graph_input_serialization_round_trip() {
        const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;

        let values = vec![
            0.05326242372393608,
            0.07497056573629379,
            0.05235547572374344,
        ];
        let file = GraphData::new(DataSource::from(vec![values]));

        let serialized = serde_json::to_string(&file).unwrap();
        assert_eq!(serialized, JSON);

        let parsed: GraphData = serde_json::from_str(JSON)
            .map_err(|e| e.to_string())
            .unwrap();
        assert_eq!(parsed, file);
    }

    /// Compatibility with the serialized elements from the mclbn256 library:
    /// the debug rendering of a field element must match the expected hex string.
    #[test]
    fn test_python_compat() {
        let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";

        let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);

        assert_eq!(format!("{:?}", source), original_addr);
    }
}
|