mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-04-25 03:01:17 -04:00
refactor: split input and witness data types (#334)
This commit is contained in:
@@ -2,6 +2,7 @@ use std::any::Any;
|
||||
|
||||
use crate::{
|
||||
circuit::{self, layouts, Tolerance},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
graph::scale_to_multiplier,
|
||||
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
|
||||
};
|
||||
@@ -38,27 +39,32 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
}
|
||||
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, inputs: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
|
||||
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = inputs[0].clone().map(|x| felt_to_i128(x));
|
||||
|
||||
let (res, intermediate_lookups) = match &self {
|
||||
HybridOp::Max { axes, .. } => (tensor::ops::max_axes(&inputs[0], axes)?, vec![]),
|
||||
HybridOp::Max { axes, .. } => (tensor::ops::max_axes(&x, axes)?, vec![]),
|
||||
HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
..
|
||||
} => (
|
||||
tensor::ops::max_pool2d(&inputs[0], padding, stride, pool_dims)?,
|
||||
tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
|
||||
vec![],
|
||||
),
|
||||
HybridOp::Min { axes, .. } => (tensor::ops::min_axes(&inputs[0], axes)?, vec![]),
|
||||
HybridOp::Min { axes, .. } => (tensor::ops::min_axes(&x, axes)?, vec![]),
|
||||
HybridOp::Softmax { scales } => {
|
||||
tensor::ops::nonlinearities::multi_dim_softmax(&inputs[0], scales.0, scales.1)
|
||||
tensor::ops::nonlinearities::multi_dim_softmax(&x, scales.0, scales.1)
|
||||
}
|
||||
HybridOp::RangeCheck(..) => (inputs[0].clone(), vec![]),
|
||||
HybridOp::RangeCheck(..) => (x.clone(), vec![]),
|
||||
};
|
||||
|
||||
// convert back to felt
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
|
||||
Ok(ForwardResult {
|
||||
output: res,
|
||||
output,
|
||||
intermediate_lookups,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1396,23 +1396,19 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
let w = region.assign(&config.lookup_input, x)?;
|
||||
|
||||
// extract integer_valuations
|
||||
let integer_evals: Tensor<i128> = w
|
||||
.get_int_evals()
|
||||
.map_err(|e| {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?
|
||||
.into_iter()
|
||||
.into();
|
||||
let felt_evals: Tensor<F> = w.get_felt_evals().map_err(|e| {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?;
|
||||
|
||||
// for key generation integer_evals will be empty and we need to return a set of unassigned values
|
||||
let output: Tensor<Value<F>> = match integer_evals.len() {
|
||||
let output: Tensor<Value<F>> = match felt_evals.len() {
|
||||
// if empty return an unknown val
|
||||
0 => Tensor::from((0..x.dims().iter().product::<usize>()).map(|_| Value::unknown())),
|
||||
// if not empty apply the nonlinearity !
|
||||
_ => {
|
||||
let x = Op::<F>::f(nl, &[integer_evals])?;
|
||||
x.output.map(|elem| Value::known(i128_to_felt(elem)))
|
||||
let x = Op::<F>::f(nl, &[felt_evals])?;
|
||||
x.output.map(|elem| Value::known(elem))
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::error::Error;
|
||||
|
||||
use crate::{
|
||||
circuit::{layouts, utils},
|
||||
fieldutils::i128_to_felt,
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
graph::scale_to_multiplier,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
};
|
||||
@@ -32,10 +32,10 @@ pub enum LookupOp {
|
||||
impl LookupOp {
|
||||
/// a value which is always in the table
|
||||
pub fn default_pair<F: PrimeField + TensorType + PartialOrd>(&self) -> (F, F) {
|
||||
let x = vec![0_i128].into_iter().into();
|
||||
let x = vec![i128_to_felt(0_i128)].into_iter().into();
|
||||
(
|
||||
<F as TensorType>::zero().unwrap(),
|
||||
i128_to_felt(Op::<F>::f(self, &[x]).unwrap().output[0]),
|
||||
Op::<F>::f(self, &[x]).unwrap().output[0],
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -45,50 +45,51 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
self
|
||||
}
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, x: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = x[0].clone().map(|x| felt_to_i128(x));
|
||||
let res = match &self {
|
||||
LookupOp::GreaterThan { a } => Ok(tensor::ops::nonlinearities::greater_than(
|
||||
&x[0],
|
||||
&x,
|
||||
f32::from(*a).into(),
|
||||
)),
|
||||
LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div(
|
||||
&x[0],
|
||||
&x,
|
||||
f32::from(*denom).into(),
|
||||
)),
|
||||
LookupOp::Recip { scale } => {
|
||||
Ok(tensor::ops::nonlinearities::recip(&x[0], *scale as u32))
|
||||
}
|
||||
LookupOp::Recip { scale } => Ok(tensor::ops::nonlinearities::recip(&x, *scale as u32)),
|
||||
LookupOp::ReLU { scale } => {
|
||||
Ok(tensor::ops::nonlinearities::leakyrelu(&x[0], *scale, 0_f64))
|
||||
Ok(tensor::ops::nonlinearities::leakyrelu(&x, *scale, 0_f64))
|
||||
}
|
||||
|
||||
LookupOp::LeakyReLU { scale, slope } => Ok(tensor::ops::nonlinearities::leakyrelu(
|
||||
&x[0],
|
||||
&x,
|
||||
*scale,
|
||||
slope.0.into(),
|
||||
)),
|
||||
LookupOp::Sigmoid { scales } => Ok(tensor::ops::nonlinearities::sigmoid(
|
||||
&x[0], scales.0, scales.1,
|
||||
)),
|
||||
LookupOp::Sigmoid { scales } => {
|
||||
Ok(tensor::ops::nonlinearities::sigmoid(&x, scales.0, scales.1))
|
||||
}
|
||||
LookupOp::Sqrt { scales } => {
|
||||
Ok(tensor::ops::nonlinearities::sqrt(&x[0], scales.0, scales.1))
|
||||
Ok(tensor::ops::nonlinearities::sqrt(&x, scales.0, scales.1))
|
||||
}
|
||||
LookupOp::Rsqrt { scales } => {
|
||||
Ok(tensor::ops::nonlinearities::rsqrt(&x, scales.0, scales.1))
|
||||
}
|
||||
LookupOp::Rsqrt { scales } => Ok(tensor::ops::nonlinearities::rsqrt(
|
||||
&x[0], scales.0, scales.1,
|
||||
)),
|
||||
LookupOp::Tanh { scales } => {
|
||||
Ok(tensor::ops::nonlinearities::tanh(&x[0], scales.0, scales.1))
|
||||
Ok(tensor::ops::nonlinearities::tanh(&x, scales.0, scales.1))
|
||||
}
|
||||
LookupOp::Erf { scales } => {
|
||||
Ok(tensor::ops::nonlinearities::erffunc(&x, scales.0, scales.1))
|
||||
}
|
||||
LookupOp::Erf { scales } => Ok(tensor::ops::nonlinearities::erffunc(
|
||||
&x[0], scales.0, scales.1,
|
||||
)),
|
||||
LookupOp::Exp { scales } => {
|
||||
Ok(tensor::ops::nonlinearities::exp(&x[0], scales.0, scales.1))
|
||||
Ok(tensor::ops::nonlinearities::exp(&x, scales.0, scales.1))
|
||||
}
|
||||
}?;
|
||||
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
|
||||
Ok(ForwardResult {
|
||||
output: res,
|
||||
output,
|
||||
intermediate_lookups: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@ use std::{any::Any, error::Error};
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::tensor::{self, Tensor, TensorError, TensorType, ValTensor};
|
||||
use crate::{
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
use self::{lookup::LookupOp, region::RegionCtx};
|
||||
@@ -25,15 +28,15 @@ pub mod region;
|
||||
|
||||
/// A struct representing the result of a forward pass.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct ForwardResult {
|
||||
pub(crate) output: Tensor<i128>,
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub(crate) output: Tensor<F>,
|
||||
pub(crate) intermediate_lookups: Vec<Tensor<i128>>,
|
||||
}
|
||||
|
||||
/// An enum representing operations that can be represented as constraints in a circuit.
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, x: &[Tensor<i128>]) -> Result<ForwardResult, TensorError>;
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
|
||||
/// Returns a string representation of the operation.
|
||||
fn as_string(&self) -> String;
|
||||
|
||||
@@ -97,7 +100,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
|
||||
self
|
||||
}
|
||||
|
||||
fn f(&self, x: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
Ok(ForwardResult {
|
||||
output: x[0].clone(),
|
||||
intermediate_lookups: vec![],
|
||||
@@ -143,7 +146,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Rescaled<F> {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
fn f(&self, x: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
if self.scale.len() != x.len() {
|
||||
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
|
||||
}
|
||||
@@ -151,10 +154,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Rescaled<F> {
|
||||
let mut rescaled_inputs = vec![];
|
||||
let inputs = &mut x.to_vec();
|
||||
for (i, ri) in inputs.iter_mut().enumerate() {
|
||||
rescaled_inputs.push(tensor::ops::nonlinearities::const_div(
|
||||
ri,
|
||||
self.scale[i].1 as f64,
|
||||
));
|
||||
let ri = ri.map(|x| felt_to_i128(x));
|
||||
let res = tensor::ops::nonlinearities::const_div(&ri, self.scale[i].1 as f64);
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
rescaled_inputs.push(output);
|
||||
}
|
||||
Op::<F>::f(&*self.inner, &rescaled_inputs)
|
||||
}
|
||||
@@ -218,7 +221,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
fn f(&self, _: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
|
||||
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
Err(TensorError::WrongMethod)
|
||||
}
|
||||
|
||||
@@ -265,10 +268,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Constant<F> {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
fn f(&self, _: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
|
||||
let values = self.quantized_values.clone();
|
||||
let int_values = values.get_int_evals().unwrap();
|
||||
let output = Tensor::new(Some(&int_values), values.dims())?;
|
||||
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let mut output = self.quantized_values.get_felt_evals().unwrap();
|
||||
// make sure its the right shape
|
||||
output.reshape(self.quantized_values.dims());
|
||||
|
||||
Ok(ForwardResult {
|
||||
output,
|
||||
|
||||
@@ -98,7 +98,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
}
|
||||
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, inputs: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
|
||||
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let mut inputs = inputs.to_vec();
|
||||
let res = match &self {
|
||||
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
|
||||
@@ -124,14 +124,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
}
|
||||
PolyOp::Add { a } => {
|
||||
if let Some(a) = a {
|
||||
inputs.push(Tensor::new(Some(&a.get_int_evals().unwrap()), a.dims())?);
|
||||
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
|
||||
}
|
||||
tensor::ops::add(&inputs)
|
||||
}
|
||||
PolyOp::Sub => tensor::ops::sub(&inputs),
|
||||
PolyOp::Mult { a } => {
|
||||
if let Some(a) = a {
|
||||
inputs.push(Tensor::new(Some(&a.get_int_evals().unwrap()), a.dims())?);
|
||||
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
|
||||
}
|
||||
tensor::ops::mult(&inputs)
|
||||
}
|
||||
@@ -141,9 +141,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
padding,
|
||||
stride,
|
||||
} => {
|
||||
inputs.push(Tensor::new(Some(&a.get_int_evals().unwrap()), a.dims())?);
|
||||
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
|
||||
if let Some(b) = bias {
|
||||
inputs.push(Tensor::new(Some(&b.get_int_evals().unwrap()), b.dims())?);
|
||||
inputs.push(Tensor::new(Some(&b.get_felt_evals().unwrap()), b.dims())?);
|
||||
}
|
||||
tensor::ops::conv(&inputs, *padding, *stride)
|
||||
}
|
||||
@@ -154,9 +154,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
output_padding,
|
||||
stride,
|
||||
} => {
|
||||
inputs.push(Tensor::new(Some(&a.get_int_evals().unwrap()), a.dims())?);
|
||||
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
|
||||
if let Some(b) = bias {
|
||||
inputs.push(Tensor::new(Some(&b.get_int_evals().unwrap()), b.dims())?);
|
||||
inputs.push(Tensor::new(Some(&b.get_felt_evals().unwrap()), b.dims())?);
|
||||
}
|
||||
tensor::ops::deconv(&inputs, *padding, *output_padding, *stride)
|
||||
}
|
||||
@@ -170,7 +170,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
return Err(TensorError::DimMismatch("pack inputs".to_string()));
|
||||
}
|
||||
|
||||
tensor::ops::pack(&inputs[0], *base as i128, *scale)
|
||||
tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale)
|
||||
}
|
||||
PolyOp::Pow(u) => {
|
||||
if 1 != inputs.len() {
|
||||
@@ -218,9 +218,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
|
||||
layouts::resize(config, region, values[..].try_into()?, scale_factor)?
|
||||
}
|
||||
PolyOp::Iff => layouts::iff(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Einsum { equation } => {
|
||||
layouts::einsum(config, region, &mut values, equation)?
|
||||
}
|
||||
PolyOp::Einsum { equation } => layouts::einsum(config, region, &mut values, equation)?,
|
||||
PolyOp::Gather { dim, index } => {
|
||||
tensor::ops::gather(&values[0].get_inner_tensor()?, *dim, index)?.into()
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
let smallest = -base.pow(self.bits as u32 - 1);
|
||||
let largest = base.pow(self.bits as u32 - 1);
|
||||
|
||||
let inputs = Tensor::from(smallest..=largest);
|
||||
let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
|
||||
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
|
||||
|
||||
self.is_assigned = true;
|
||||
@@ -75,14 +75,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
|| format!("nl_i_col row {}", row_offset),
|
||||
self.table_input,
|
||||
row_offset,
|
||||
|| Value::known(i128_to_felt::<F>(*input)),
|
||||
|| Value::known(*input),
|
||||
)?;
|
||||
|
||||
table.assign_cell(
|
||||
|| format!("nl_o_col row {}", row_offset),
|
||||
self.table_output,
|
||||
row_offset,
|
||||
|| Value::known(i128_to_felt::<F>(evals.output[row_offset])),
|
||||
|| Value::known(evals.output[row_offset]),
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
|
||||
20
src/eth.rs
20
src/eth.rs
@@ -1,4 +1,4 @@
|
||||
use crate::graph::input::{CallsToAccount, DataSource, GraphWitness};
|
||||
use crate::graph::input::{CallsToAccount, GraphWitness, WitnessSource};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::GraphSettings;
|
||||
use crate::pfsys::evm::{DeploymentCode, EvmVerificationError};
|
||||
@@ -164,12 +164,12 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
let mut instance_idx = 0;
|
||||
let mut contract_instance_offset = 0;
|
||||
|
||||
if let DataSource::OnChain(source) = witness.input_data {
|
||||
if let WitnessSource::OnChain(source) = witness.input_data {
|
||||
for call in source.calls {
|
||||
calls_to_accounts.push(call);
|
||||
instance_idx += 1;
|
||||
}
|
||||
} else if let DataSource::File(source) = witness.input_data {
|
||||
} else if let WitnessSource::File(source) = witness.input_data {
|
||||
if settings.run_args.input_visibility.is_public() {
|
||||
instance_idx += source.len();
|
||||
for s in source {
|
||||
@@ -178,7 +178,7 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
}
|
||||
}
|
||||
|
||||
if let DataSource::OnChain(source) = witness.output_data {
|
||||
if let WitnessSource::OnChain(source) = witness.output_data {
|
||||
let output_scales = settings.model_output_scales;
|
||||
for call in source.calls {
|
||||
calls_to_accounts.push(call);
|
||||
@@ -393,7 +393,7 @@ pub fn get_provider(rpc_url: &str) -> Result<Provider<Http>, Box<dyn Error>> {
|
||||
/// the number of decimals of the floating point value on chain.
|
||||
pub async fn test_on_chain_data<M: 'static + Middleware>(
|
||||
client: Arc<M>,
|
||||
data: &Vec<Vec<f32>>,
|
||||
data: &[Vec<f32>],
|
||||
) -> Result<Vec<CallsToAccount>, Box<dyn Error>> {
|
||||
let (contract, decimals) = setup_test_contract(client.clone(), data).await?;
|
||||
|
||||
@@ -454,7 +454,7 @@ pub async fn evm_quantize<M: 'static + Middleware>(
|
||||
client: Arc<M>,
|
||||
scales: Vec<f64>,
|
||||
data: &(Vec<ethers::types::Bytes>, Vec<u8>),
|
||||
) -> Result<Vec<i128>, Box<dyn Error>> {
|
||||
) -> Result<Vec<Fr>, Box<dyn Error>> {
|
||||
// save the sol to a tmp file
|
||||
let mut sol_path = std::env::temp_dir();
|
||||
sol_path.push("quantizedata.sol");
|
||||
@@ -495,7 +495,11 @@ pub async fn evm_quantize<M: 'static + Middleware>(
|
||||
.call()
|
||||
.await;
|
||||
|
||||
let results = results.unwrap();
|
||||
let results = results
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|x| crate::fieldutils::i128_to_felt(*x))
|
||||
.collect::<Vec<Fr>>();
|
||||
info!("evm quantization results: {:#?}", results,);
|
||||
Ok(results.to_vec())
|
||||
}
|
||||
@@ -513,7 +517,7 @@ fn get_sol_contract_factory<M: 'static + Middleware>(
|
||||
if size > MAX_RUNTIME_BYTECODE_SIZE {
|
||||
// `_runtime_bytecode` exceeds the limit
|
||||
panic!(
|
||||
"Solidity runtime bytecode size is: {:#?},
|
||||
"Solidity runtime bytecode size is: {:#?},
|
||||
which exceeds 24577 bytes limit.
|
||||
Try setting '--optimzer-runs 1' when generating the verifier
|
||||
so SOLC can optimize for the smallest deployment",
|
||||
|
||||
@@ -6,8 +6,8 @@ use crate::commands::{Cli, Commands, RunArgs};
|
||||
use crate::eth::{deploy_da_verifier_via_solidity, deploy_verifier_via_solidity};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::eth::{fix_verifier_sol, get_contract_artifacts, verify_proof_via_solidity};
|
||||
use crate::graph::input::{DataSource, GraphInput};
|
||||
use crate::graph::{scale_to_multiplier, GraphCircuit, GraphSettings, GraphWitness, Model};
|
||||
use crate::graph::input::{GraphInput, WitnessSource};
|
||||
use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::{TestDataSource, TestSources, Visibility};
|
||||
use crate::pfsys::evm::aggregation::{AggregationCircuit, PoseidonTranscript};
|
||||
@@ -497,9 +497,9 @@ pub(crate) async fn gen_witness(
|
||||
let data = GraphInput::from_path(data)?;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
circuit.load_inputs(&data).await?;
|
||||
circuit.load_graph_input(&data).await?;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
circuit.load_inputs(&data)?;
|
||||
circuit.load_graph_input(&data)?;
|
||||
|
||||
let start_time = Instant::now();
|
||||
|
||||
@@ -516,22 +516,22 @@ pub(crate) async fn gen_witness(
|
||||
res.outputs.iter().map(|t| t.dims()).collect_vec()
|
||||
);
|
||||
|
||||
let output_scales = circuit.model.graph.get_output_scales();
|
||||
let output_scales = output_scales
|
||||
let input_witness: Vec<Vec<Fr>> = res
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|scale| scale_to_multiplier(*scale));
|
||||
.map(|t| t.clone().into_iter().collect_vec())
|
||||
.collect();
|
||||
|
||||
let float_res: Vec<Vec<f32>> = res
|
||||
let output_witness: Vec<Vec<Fr>> = res
|
||||
.outputs
|
||||
.iter()
|
||||
.zip(output_scales)
|
||||
.map(|(t, scale)| t.iter().map(|e| ((*e as f64 / scale) as f32)).collect_vec())
|
||||
.map(|t| t.clone().into_iter().collect_vec())
|
||||
.collect();
|
||||
trace!("model forward pass output: {:?}", float_res);
|
||||
trace!("model forward pass output: {:?}", output_witness);
|
||||
|
||||
let witness = GraphWitness {
|
||||
input_data: data.input_data,
|
||||
output_data: DataSource::File(float_res),
|
||||
input_data: WitnessSource::File(input_witness),
|
||||
output_data: WitnessSource::File(output_witness),
|
||||
processed_inputs: res.processed_inputs,
|
||||
processed_params: res.processed_params,
|
||||
processed_outputs: res.processed_outputs,
|
||||
@@ -587,7 +587,7 @@ pub(crate) fn init_bar(len: u64) -> ProgressBar {
|
||||
)
|
||||
.unwrap()
|
||||
.progress_chars("##-");
|
||||
pb.set_style(sty.clone());
|
||||
pb.set_style(sty);
|
||||
|
||||
pb
|
||||
}
|
||||
@@ -653,7 +653,7 @@ pub(crate) async fn calibrate(
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
circuit
|
||||
.load_inputs(&chunk)
|
||||
.load_graph_input(&chunk)
|
||||
.await
|
||||
.map_err(|_| "failed to load circuit inputs")
|
||||
.unwrap();
|
||||
@@ -762,9 +762,9 @@ pub(crate) async fn mock(
|
||||
let data = GraphWitness::from_path(data_path.clone())?;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
circuit.load_data(&data, None).await?;
|
||||
circuit.load_graph_witness(&data, None).await?;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
circuit.load_data(&data)?;
|
||||
circuit.load_graph_witness(&data)?;
|
||||
|
||||
let public_inputs = circuit.prepare_public_inputs(&data)?;
|
||||
|
||||
@@ -901,7 +901,7 @@ pub(crate) fn create_evm_data_attestation_verifier(
|
||||
|
||||
let data = GraphWitness::from_path(witness)?;
|
||||
|
||||
let output_data = if let DataSource::OnChain(source) = data.output_data {
|
||||
let output_data = if let WitnessSource::OnChain(source) = data.output_data {
|
||||
if !visibility.output.is_public() {
|
||||
todo!("we currently don't support private output data on chain")
|
||||
}
|
||||
@@ -914,7 +914,7 @@ pub(crate) fn create_evm_data_attestation_verifier(
|
||||
None
|
||||
};
|
||||
|
||||
let input_data = if let DataSource::OnChain(source) = data.input_data {
|
||||
let input_data = if let WitnessSource::OnChain(source) = data.input_data {
|
||||
if !visibility.input.is_public() {
|
||||
todo!("we currently don't support private input data on chain")
|
||||
}
|
||||
@@ -1096,7 +1096,7 @@ pub(crate) async fn setup_test_evm_witness(
|
||||
};
|
||||
|
||||
circuit
|
||||
.load_data(&data, Some(test_on_chain_witness))
|
||||
.load_graph_witness(&data, Some(test_on_chain_witness))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
@@ -1119,7 +1119,7 @@ pub(crate) async fn prove(
|
||||
let circuit_settings = GraphSettings::load(&settings_path)?;
|
||||
let mut circuit = GraphCircuit::from_settings(&circuit_settings, &model_path, check_mode)?;
|
||||
|
||||
circuit.load_data(&data, None).await?;
|
||||
circuit.load_graph_witness(&data, None).await?;
|
||||
let public_inputs = circuit.prepare_public_inputs(&data)?;
|
||||
|
||||
let circuit_settings = circuit.settings.clone();
|
||||
@@ -1202,7 +1202,7 @@ pub(crate) async fn fuzz(
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, ¶ms)
|
||||
.map_err(Box::<dyn Error>::from)?;
|
||||
|
||||
circuit.load_data(&data, None).await?;
|
||||
circuit.load_graph_witness(&data, None).await?;
|
||||
let public_inputs = circuit.prepare_public_inputs(&data)?;
|
||||
|
||||
let strategy = KZGSingleStrategy::new(¶ms);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -20,8 +21,10 @@ type Decimals = u8;
|
||||
type Call = String;
|
||||
type RPCUrl = String;
|
||||
|
||||
/// Inner elements of inputs/outputs coming from a file
|
||||
pub type FileSourceInner = Vec<Vec<f32>>;
|
||||
/// Inner elements of inputs coming from a file
|
||||
pub type FileSourceInner = Vec<Vec<f64>>;
|
||||
/// Inner elements of witness coming from a witness
|
||||
pub type WitnessFileSourceInner = Vec<Vec<Fp>>;
|
||||
/// Inner elements of inputs/outputs coming from on-chain
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSourceInner {
|
||||
@@ -42,13 +45,14 @@ impl OnChainSourceInner {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Create dummy local on-chain data to test the OnChain data source
|
||||
pub async fn test_from_file_data(
|
||||
data: &FileSourceInner,
|
||||
data: &WitnessFileSourceInner,
|
||||
scales: Vec<u32>,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
rpc: Option<&str>,
|
||||
) -> Result<(Vec<Tensor<i128>>, Self), Box<dyn std::error::Error>> {
|
||||
) -> Result<(Vec<Tensor<Fp>>, Self), Box<dyn std::error::Error>> {
|
||||
use crate::eth::{evm_quantize, read_on_chain_inputs, test_on_chain_data};
|
||||
use crate::graph::scale_to_multiplier;
|
||||
use itertools::Itertools;
|
||||
use log::debug;
|
||||
|
||||
// Set up local anvil instance for reading on-chain data
|
||||
@@ -56,13 +60,25 @@ impl OnChainSourceInner {
|
||||
|
||||
let address = client.address();
|
||||
|
||||
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
|
||||
let scales: Vec<f64> = scales.into_iter().map(scale_to_multiplier).collect();
|
||||
|
||||
// unquantize data
|
||||
let float_data = data
|
||||
.iter()
|
||||
.zip(scales.iter())
|
||||
.map(|(t, scale)| {
|
||||
t.iter()
|
||||
.map(|e| ((crate::fieldutils::felt_to_i128(*e) as f64 / scale) as f32))
|
||||
.collect_vec()
|
||||
})
|
||||
.collect::<Vec<Vec<f32>>>();
|
||||
|
||||
let calls_to_accounts = test_on_chain_data(client.clone(), &float_data).await?;
|
||||
debug!("Calls to accounts: {:?}", calls_to_accounts);
|
||||
let inputs = read_on_chain_inputs(client.clone(), address, &calls_to_accounts).await?;
|
||||
debug!("Inputs: {:?}", inputs);
|
||||
|
||||
let mut quantized_evm_inputs = vec![];
|
||||
let scales: Vec<f64> = scales.into_iter().map(scale_to_multiplier).collect();
|
||||
|
||||
let mut prev = 0;
|
||||
for (idx, i) in data.iter().enumerate() {
|
||||
@@ -81,9 +97,9 @@ impl OnChainSourceInner {
|
||||
}
|
||||
|
||||
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
|
||||
let mut inputs: Vec<Tensor<i128>> = vec![];
|
||||
let mut inputs: Vec<Tensor<Fp>> = vec![];
|
||||
for (input, shape) in vec![quantized_evm_inputs].iter().zip(shapes) {
|
||||
let mut t: Tensor<i128> = input.iter().cloned().collect();
|
||||
let mut t: Tensor<Fp> = input.iter().cloned().collect();
|
||||
t.reshape(&shape);
|
||||
inputs.push(t);
|
||||
}
|
||||
@@ -141,26 +157,29 @@ impl From<OnChainSourceInner> for DataSource {
|
||||
}
|
||||
|
||||
/// Enum that defines source of the inputs/outputs to the EZKL model
|
||||
/// used for f32 to f64 conversion
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
enum DataSourceF64 {
|
||||
pub enum WitnessSource {
|
||||
/// .json File data source.
|
||||
File(WitnessFileSourceInner),
|
||||
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
|
||||
OnChain(OnChainSourceInner),
|
||||
File(Vec<Vec<f64>>),
|
||||
}
|
||||
impl Default for WitnessSource {
|
||||
fn default() -> Self {
|
||||
WitnessSource::File(vec![vec![]])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DataSource> for DataSourceF64 {
|
||||
fn from(source: DataSource) -> Self {
|
||||
match source {
|
||||
DataSource::File(data) => {
|
||||
let data = data
|
||||
.iter()
|
||||
.map(|v| v.iter().map(|&f| f as f64).collect::<Vec<_>>())
|
||||
.collect::<Vec<_>>();
|
||||
DataSourceF64::File(data)
|
||||
}
|
||||
DataSource::OnChain(source) => DataSourceF64::OnChain(source),
|
||||
}
|
||||
impl From<WitnessFileSourceInner> for WitnessSource {
|
||||
fn from(data: WitnessFileSourceInner) -> Self {
|
||||
WitnessSource::File(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OnChainSourceInner> for WitnessSource {
|
||||
fn from(data: OnChainSourceInner) -> Self {
|
||||
WitnessSource::OnChain(data)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,9 +189,9 @@ impl From<DataSource> for DataSourceF64 {
|
||||
pub struct GraphWitness {
|
||||
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
|
||||
/// TODO: Add retrieve from on-chain functionality
|
||||
pub input_data: DataSource,
|
||||
pub input_data: WitnessSource,
|
||||
/// The expected output of the model (can be empty vectors if outputs are not being constrained).
|
||||
pub output_data: DataSource,
|
||||
pub output_data: WitnessSource,
|
||||
/// Optional hashes of the inputs (can be None if there are no commitments). Wrapped as Option for backwards compatibility
|
||||
pub processed_inputs: Option<ModuleForwardResult>,
|
||||
/// Optional hashes of the params (can be None if there are no commitments). Wrapped as Option for backwards compatibility
|
||||
@@ -183,7 +202,7 @@ pub struct GraphWitness {
|
||||
|
||||
impl GraphWitness {
|
||||
///
|
||||
pub fn new(input_data: DataSource, output_data: DataSource) -> Self {
|
||||
pub fn new(input_data: WitnessSource, output_data: WitnessSource) -> Self {
|
||||
GraphWitness {
|
||||
input_data,
|
||||
output_data,
|
||||
@@ -213,14 +232,6 @@ pub struct GraphInput {
|
||||
pub input_data: DataSource,
|
||||
}
|
||||
|
||||
impl From<GraphWitness> for GraphInput {
|
||||
fn from(witness: GraphWitness) -> Self {
|
||||
GraphInput {
|
||||
input_data: witness.input_data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphInput {
|
||||
///
|
||||
pub fn new(input_data: DataSource) -> Self {
|
||||
@@ -302,11 +313,7 @@ impl GraphInput {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use halo2curves::{
|
||||
bn256::{Fr as Fp, G1Affine},
|
||||
ff::PrimeField,
|
||||
serde::SerdeObject,
|
||||
};
|
||||
use halo2curves::{bn256::G1Affine, ff::PrimeField, serde::SerdeObject};
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// converts fp into Vec<u64>
|
||||
@@ -400,6 +407,27 @@ impl ToPyObject for DataSource {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for WitnessSource {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
match self {
|
||||
WitnessSource::File(data) => {
|
||||
let field_elem: Vec<Vec<Vec<u64>>> = data
|
||||
.iter()
|
||||
.map(|x| x.iter().map(field_to_vecu64).collect())
|
||||
.collect();
|
||||
field_elem.to_object(py)
|
||||
}
|
||||
WitnessSource::OnChain(source) => {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("rpc_url", &source.rpc).unwrap();
|
||||
dict.set_item("calls_to_accounts", &source.calls).unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for GraphWitness {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -459,8 +487,7 @@ impl Serialize for GraphInput {
|
||||
S: Serializer,
|
||||
{
|
||||
let mut state = serializer.serialize_struct("GraphInput", 4)?;
|
||||
let input_data: DataSourceF64 = self.input_data.clone().into();
|
||||
state.serialize_field("input_data", &input_data)?;
|
||||
state.serialize_field("input_data", &self.input_data)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
@@ -488,16 +515,37 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
}
|
||||
}
|
||||
|
||||
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
||||
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
||||
impl<'de> Deserialize<'de> for WitnessSource {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
let first_try: Result<WitnessFileSourceInner, _> = serde_json::from_str(this_json.get());
|
||||
|
||||
if let Ok(t) = first_try {
|
||||
return Ok(WitnessSource::File(t));
|
||||
}
|
||||
let second_try: Result<OnChainSourceInner, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = second_try {
|
||||
return Ok(WitnessSource::OnChain(t));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom("failed to deserialize DataSource"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for GraphWitness {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut state = serializer.serialize_struct("GraphWitness", 4)?;
|
||||
let input_data: DataSourceF64 = self.input_data.clone().into();
|
||||
let output_data: DataSourceF64 = self.output_data.clone().into();
|
||||
state.serialize_field("input_data", &input_data)?;
|
||||
state.serialize_field("output_data", &output_data)?;
|
||||
state.serialize_field("input_data", &self.input_data)?;
|
||||
state.serialize_field("output_data", &self.output_data)?;
|
||||
|
||||
if let Some(processed_inputs) = &self.processed_inputs {
|
||||
state.serialize_field("processed_inputs", &processed_inputs)?;
|
||||
@@ -540,9 +588,9 @@ mod tests {
|
||||
// this is for backwards compatibility with the old format
|
||||
fn test_graph_input_serialization_round_trip() {
|
||||
let file = GraphInput::new(DataSource::File(vec![vec![
|
||||
0.053_262_424,
|
||||
0.074_970_566,
|
||||
0.052_355_476,
|
||||
0.05326242372393608,
|
||||
0.07497056573629379,
|
||||
0.05235547572374344,
|
||||
]]));
|
||||
|
||||
let serialized = serde_json::to_string(&file).unwrap();
|
||||
@@ -555,7 +603,6 @@ mod tests {
|
||||
let graph_input3 = serde_json::from_str::<GraphInput>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
println!("{:?}", graph_input3.input_data);
|
||||
assert_eq!(graph_input3, file);
|
||||
}
|
||||
}
|
||||
|
||||
191
src/graph/mod.rs
191
src/graph/mod.rs
@@ -11,11 +11,12 @@ pub mod utilities;
|
||||
/// Representations of a computational graph's variables.
|
||||
pub mod vars;
|
||||
|
||||
pub use input::{DataSource, GraphWitness};
|
||||
use halo2_proofs::circuit::Value;
|
||||
pub use input::{DataSource, GraphWitness, WitnessSource};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use self::input::OnChainSourceInner;
|
||||
use self::input::{FileSourceInner, GraphInput};
|
||||
use self::input::{FileSourceInner, GraphInput, WitnessFileSourceInner};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::CheckMode;
|
||||
@@ -24,7 +25,7 @@ use crate::fieldutils::i128_to_felt;
|
||||
use crate::graph::modules::ModuleInstanceOffset;
|
||||
use crate::tensor::{Tensor, ValTensor};
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, Value},
|
||||
circuit::Layouter,
|
||||
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
|
||||
};
|
||||
use halo2curves::bn256::{self, Fr as Fp};
|
||||
@@ -99,9 +100,9 @@ const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct ForwardResult {
|
||||
/// The inputs of the forward pass
|
||||
pub inputs: Vec<Tensor<i128>>,
|
||||
pub inputs: Vec<Tensor<Fp>>,
|
||||
/// The output of the forward pass
|
||||
pub outputs: Vec<Tensor<i128>>,
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
/// Any hashes of inputs generated during the forward pass
|
||||
pub processed_inputs: Option<ModuleForwardResult>,
|
||||
/// Any hashes of params generated during the forward pass
|
||||
@@ -173,9 +174,9 @@ pub struct GraphCircuit {
|
||||
/// The model / graph of computations.
|
||||
pub model: Model,
|
||||
/// Vector of input tensors to the model / graph of computations.
|
||||
pub inputs: Vec<Tensor<i128>>,
|
||||
pub inputs: Vec<Tensor<Fp>>,
|
||||
/// Vector of input tensors to the model / graph of computations.
|
||||
pub outputs: Vec<Tensor<i128>>,
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
/// The settings of the model / graph of computations.
|
||||
pub settings: GraphSettings,
|
||||
/// The settings of the model's modules.
|
||||
@@ -230,9 +231,9 @@ impl GraphCircuit {
|
||||
check_mode: CheckMode,
|
||||
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
|
||||
// placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
|
||||
let mut inputs: Vec<Tensor<i128>> = vec![];
|
||||
let mut inputs: Vec<Tensor<Fp>> = vec![];
|
||||
for shape in model.graph.input_shapes() {
|
||||
let t: Tensor<i128> = Tensor::new(None, &shape).unwrap();
|
||||
let t: Tensor<Fp> = Tensor::new(None, &shape).unwrap();
|
||||
inputs.push(t);
|
||||
}
|
||||
|
||||
@@ -277,9 +278,9 @@ impl GraphCircuit {
|
||||
check_mode: CheckMode,
|
||||
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
|
||||
// placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
|
||||
let mut inputs: Vec<Tensor<i128>> = vec![];
|
||||
let mut inputs: Vec<Tensor<Fp>> = vec![];
|
||||
for shape in model.graph.input_shapes() {
|
||||
let t: Tensor<i128> = Tensor::new(None, &shape).unwrap();
|
||||
let t: Tensor<Fp> = Tensor::new(None, &shape).unwrap();
|
||||
inputs.push(t);
|
||||
}
|
||||
|
||||
@@ -299,9 +300,14 @@ impl GraphCircuit {
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
/// load inputs and outputs for the model
|
||||
pub fn load_data(&mut self, data: &GraphWitness) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.load_inputs(&data.clone().into())?;
|
||||
self.load_outputs(data)?;
|
||||
pub fn load_graph_witness(
|
||||
&mut self,
|
||||
data: &GraphWitness,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.inputs =
|
||||
self.process_witness_source(&data.input_data, self.model.graph.input_shapes())?;
|
||||
self.outputs =
|
||||
self.process_witness_source(&data.output_data, self.model.graph.output_shapes())?;
|
||||
// load the module settings
|
||||
self.module_settings = ModuleSettings::from(data);
|
||||
|
||||
@@ -310,7 +316,7 @@ impl GraphCircuit {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// load inputs and outputs for the model
|
||||
pub async fn load_data(
|
||||
pub async fn load_graph_witness(
|
||||
&mut self,
|
||||
data: &GraphWitness,
|
||||
test_on_chain_data: Option<TestOnChainData>,
|
||||
@@ -322,8 +328,20 @@ impl GraphCircuit {
|
||||
self.populate_on_chain_test_data(&mut data, test_path)
|
||||
.await?;
|
||||
} else {
|
||||
self.load_inputs(&data.clone().into()).await?;
|
||||
self.load_outputs(&data).await?;
|
||||
self.inputs = self
|
||||
.process_witness_source(
|
||||
&data.input_data,
|
||||
self.model.graph.input_shapes(),
|
||||
self.model.graph.get_input_scales(),
|
||||
)
|
||||
.await?;
|
||||
self.outputs = self
|
||||
.process_witness_source(
|
||||
&data.output_data,
|
||||
self.model.graph.output_shapes(),
|
||||
self.model.graph.get_output_scales(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// load the module settings
|
||||
@@ -358,11 +376,7 @@ impl GraphCircuit {
|
||||
|
||||
let mut pi_inner: Vec<Vec<Fp>> = public_inputs
|
||||
.iter()
|
||||
.map(|i| {
|
||||
i.iter()
|
||||
.map(|e| i128_to_felt::<Fp>(*e))
|
||||
.collect::<Vec<Fp>>()
|
||||
})
|
||||
.map(|i| i.clone().into_iter().collect::<Vec<Fp>>())
|
||||
.collect::<Vec<Vec<Fp>>>();
|
||||
|
||||
let module_instances =
|
||||
@@ -377,7 +391,10 @@ impl GraphCircuit {
|
||||
|
||||
///
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub fn load_inputs(&mut self, data: &GraphInput) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn load_graph_input(
|
||||
&mut self,
|
||||
data: &GraphInput,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let shapes = self.model.graph.input_shapes();
|
||||
let scales = vec![self.settings.run_args.scale; shapes.len()];
|
||||
self.inputs = self.process_data_source(&data.input_data, shapes, scales)?;
|
||||
@@ -386,7 +403,7 @@ impl GraphCircuit {
|
||||
|
||||
///
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub async fn load_inputs(
|
||||
pub async fn load_graph_input(
|
||||
&mut self,
|
||||
data: &GraphInput,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
@@ -395,29 +412,7 @@ impl GraphCircuit {
|
||||
self.inputs = self
|
||||
.process_data_source(&data.input_data, shapes, scales)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub fn load_outputs(&mut self, data: &GraphWitness) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let out_scales = self.model.graph.get_output_scales();
|
||||
let shapes = self.model.graph.output_shapes();
|
||||
self.outputs = self.process_data_source(&data.output_data, shapes, out_scales)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
///
|
||||
pub async fn load_outputs(
|
||||
&mut self,
|
||||
data: &GraphWitness,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let out_scales = self.model.graph.get_output_scales();
|
||||
let shapes = self.model.graph.output_shapes();
|
||||
self.outputs = self
|
||||
.process_data_source(&data.output_data, shapes, out_scales)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -428,7 +423,7 @@ impl GraphCircuit {
|
||||
data: &DataSource,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
scales: Vec<u32>,
|
||||
) -> Result<Vec<Tensor<i128>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
match &data {
|
||||
DataSource::OnChain(_) => {
|
||||
panic!("Cannot use on-chain data source as input for wasm rn.")
|
||||
@@ -437,6 +432,21 @@ impl GraphCircuit {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
/// Process the data source for the model
|
||||
fn process_witness_source(
|
||||
&mut self,
|
||||
data: &WitnessSource,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
match &data {
|
||||
WitnessSource::OnChain(_) => {
|
||||
panic!("Cannot use on-chain data source as input for wasm rn.")
|
||||
}
|
||||
WitnessSource::File(file_data) => self.load_witness_file_data(file_data, &shapes),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Process the data source for the model
|
||||
async fn process_data_source(
|
||||
@@ -444,7 +454,7 @@ impl GraphCircuit {
|
||||
data: &DataSource,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
scales: Vec<u32>,
|
||||
) -> Result<Vec<Tensor<i128>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
match &data {
|
||||
DataSource::OnChain(source) => {
|
||||
let mut per_item_scale = vec![];
|
||||
@@ -458,6 +468,27 @@ impl GraphCircuit {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Process the data source for the model
|
||||
async fn process_witness_source(
|
||||
&mut self,
|
||||
data: &WitnessSource,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
scales: Vec<u32>,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
match &data {
|
||||
WitnessSource::OnChain(source) => {
|
||||
let mut per_item_scale = vec![];
|
||||
for (i, shape) in shapes.iter().enumerate() {
|
||||
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
|
||||
}
|
||||
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
|
||||
.await
|
||||
}
|
||||
WitnessSource::File(file_data) => self.load_witness_file_data(file_data, &shapes),
|
||||
}
|
||||
}
|
||||
|
||||
/// Prepare on chain test data
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub async fn load_on_chain_data(
|
||||
@@ -465,7 +496,7 @@ impl GraphCircuit {
|
||||
source: OnChainSourceInner,
|
||||
shapes: &Vec<Vec<usize>>,
|
||||
scales: Vec<u32>,
|
||||
) -> Result<Vec<Tensor<i128>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
|
||||
let (_, client) = setup_eth_backend(Some(&source.rpc)).await?;
|
||||
let inputs = read_on_chain_inputs(client.clone(), client.address(), &source.calls).await?;
|
||||
@@ -477,9 +508,9 @@ impl GraphCircuit {
|
||||
)
|
||||
.await?;
|
||||
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
|
||||
let mut inputs: Vec<Tensor<i128>> = vec![];
|
||||
let mut inputs: Vec<Tensor<Fp>> = vec![];
|
||||
for (input, shape) in vec![quantized_evm_inputs].iter().zip(shapes) {
|
||||
let mut t: Tensor<i128> = input.iter().cloned().collect();
|
||||
let mut t: Tensor<Fp> = input.iter().cloned().collect();
|
||||
t.reshape(shape);
|
||||
inputs.push(t);
|
||||
}
|
||||
@@ -493,16 +524,16 @@ impl GraphCircuit {
|
||||
file_data: &FileSourceInner,
|
||||
shapes: &Vec<Vec<usize>>,
|
||||
scales: Vec<u32>,
|
||||
) -> Result<Vec<Tensor<i128>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
// quantize the supplied data using the provided scale.
|
||||
let mut data: Vec<Tensor<i128>> = vec![];
|
||||
let mut data: Vec<Tensor<Fp>> = vec![];
|
||||
for ((d, shape), scale) in file_data.iter().zip(shapes).zip(scales) {
|
||||
let t: Vec<i128> = d
|
||||
let t: Vec<Fp> = d
|
||||
.par_iter()
|
||||
.map(|x| quantize_float(x, 0.0, scale).unwrap())
|
||||
.map(|x| i128_to_felt(quantize_float(x, 0.0, scale).unwrap()))
|
||||
.collect();
|
||||
|
||||
let mut t: Tensor<i128> = t.into_iter().into();
|
||||
let mut t: Tensor<Fp> = t.into_iter().into();
|
||||
t.reshape(shape);
|
||||
|
||||
data.push(t);
|
||||
@@ -510,6 +541,22 @@ impl GraphCircuit {
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn load_witness_file_data(
|
||||
&mut self,
|
||||
file_data: &WitnessFileSourceInner,
|
||||
shapes: &Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
// quantize the supplied data using the provided scale.
|
||||
let mut data: Vec<Tensor<Fp>> = vec![];
|
||||
for (d, shape) in file_data.iter().zip(shapes) {
|
||||
let mut t: Tensor<Fp> = d.clone().into_iter().into();
|
||||
t.reshape(shape);
|
||||
data.push(t);
|
||||
}
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
/// Calibrate the circuit to the supplied data.
|
||||
pub fn calibrate(&mut self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let res = self.forward()?;
|
||||
@@ -575,7 +622,7 @@ impl GraphCircuit {
|
||||
if visibility.params.requires_processing() {
|
||||
let params = self.model.get_all_consts();
|
||||
let flattened_params = flatten_valtensors(params)?
|
||||
.get_int_evals()?
|
||||
.get_felt_evals()?
|
||||
.into_iter()
|
||||
.into();
|
||||
processed_params = Some(GraphModules::forward(
|
||||
@@ -638,8 +685,8 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
let input_data = match &data.input_data {
|
||||
DataSource::File(input_data) => input_data,
|
||||
DataSource::OnChain(_) => panic!(
|
||||
WitnessSource::File(input_data) => input_data,
|
||||
WitnessSource::OnChain(_) => panic!(
|
||||
"Cannot use on-chain data source as input for on-chain test.
|
||||
Will manually populate on-chain data from file source instead"
|
||||
),
|
||||
@@ -647,7 +694,7 @@ impl GraphCircuit {
|
||||
// Get the flatten length of input_data
|
||||
let length = input_data.iter().map(|x| x.len()).sum();
|
||||
let scales = vec![self.settings.run_args.scale; length];
|
||||
let datam: (Vec<Tensor<i128>>, OnChainSourceInner) =
|
||||
let datam: (Vec<Tensor<Fp>>, OnChainSourceInner) =
|
||||
OnChainSourceInner::test_from_file_data(
|
||||
input_data,
|
||||
scales,
|
||||
@@ -658,7 +705,13 @@ impl GraphCircuit {
|
||||
self.inputs = datam.0;
|
||||
data.input_data = datam.1.into();
|
||||
} else {
|
||||
self.load_inputs(&data.clone().into()).await?;
|
||||
self.inputs = self
|
||||
.process_witness_source(
|
||||
&data.input_data,
|
||||
self.model.graph.input_shapes(),
|
||||
self.model.graph.get_input_scales(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
if matches!(
|
||||
test_on_chain_data.data_sources.output,
|
||||
@@ -670,13 +723,13 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
let output_data = match &data.output_data {
|
||||
DataSource::File(output_data) => output_data,
|
||||
DataSource::OnChain(_) => panic!(
|
||||
WitnessSource::File(output_data) => output_data,
|
||||
WitnessSource::OnChain(_) => panic!(
|
||||
"Cannot use on-chain data source as output for on-chain test.
|
||||
Will manually populate on-chain data from file source instead"
|
||||
),
|
||||
};
|
||||
let datum: (Vec<Tensor<i128>>, OnChainSourceInner) =
|
||||
let datum: (Vec<Tensor<Fp>>, OnChainSourceInner) =
|
||||
OnChainSourceInner::test_from_file_data(
|
||||
output_data,
|
||||
self.model.graph.get_output_scales(),
|
||||
@@ -687,7 +740,13 @@ impl GraphCircuit {
|
||||
self.outputs = datum.0;
|
||||
data.output_data = datum.1.into();
|
||||
} else {
|
||||
self.load_outputs(data).await?;
|
||||
self.outputs = self
|
||||
.process_witness_source(
|
||||
&data.input_data,
|
||||
self.model.graph.input_shapes(),
|
||||
self.model.graph.get_output_scales(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
// Save the updated GraphInput struct to the data_path
|
||||
data.save(test_on_chain_data.data)?;
|
||||
@@ -758,7 +817,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
let mut inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|i| ValTensor::from(<Tensor<i128> as Into<Tensor<Value<Fp>>>>::into(i.clone())))
|
||||
.map(|i| ValTensor::from(i.map(|x| Value::known(x))))
|
||||
.collect::<Vec<ValTensor<Fp>>>();
|
||||
|
||||
let mut instance_offset = ModuleInstanceOffset::new();
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::fieldutils::felt_to_i128;
|
||||
use crate::{
|
||||
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
|
||||
commands::RunArgs,
|
||||
@@ -40,7 +41,7 @@ use tract_onnx::prelude::Framework;
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ForwardResult {
|
||||
/// The outputs of the forward pass.
|
||||
pub outputs: Vec<Tensor<i128>>,
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
/// The maximum value of any input to a lookup operation.
|
||||
pub max_lookup_inputs: i128,
|
||||
}
|
||||
@@ -182,6 +183,14 @@ impl ParsedNodes {
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Returns the fixed point scale of the computational graph's inputs
|
||||
pub fn get_input_scales(&self) -> Vec<u32> {
|
||||
let input_nodes = self.inputs.iter();
|
||||
input_nodes
|
||||
.flat_map(|o| self.nodes.get(o).unwrap().out_scales())
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Returns the fixed point scale of the computational graph's outputs
|
||||
pub fn get_output_scales(&self) -> Vec<u32> {
|
||||
let output_nodes = self.outputs.iter();
|
||||
@@ -280,8 +289,8 @@ impl Model {
|
||||
/// * `reader` - A reader for an Onnx file.
|
||||
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
|
||||
/// * `run_args` - [RunArgs]
|
||||
pub fn forward(&self, model_inputs: &[Tensor<i128>]) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
let mut results: BTreeMap<&usize, Tensor<i128>> = BTreeMap::new();
|
||||
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
let mut results: BTreeMap<&usize, Tensor<Fp>> = BTreeMap::new();
|
||||
let mut max_lookup_inputs = 0;
|
||||
let mut input_idx = 0;
|
||||
for (idx, n) in self.graph.nodes.iter() {
|
||||
@@ -305,7 +314,7 @@ impl Model {
|
||||
if !n.required_lookups().is_empty() {
|
||||
let mut max = 0;
|
||||
for i in &inputs {
|
||||
max = max.max(i.iter().map(|x| x.abs()).max().unwrap());
|
||||
max = max.max(i.iter().map(|x| felt_to_i128(*x).abs()).max().unwrap());
|
||||
}
|
||||
max_lookup_inputs = max_lookup_inputs.max(max);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ use crate::circuit::modules::elgamal::{ElGamalConfig, ElGamalGadget, ElGamalVari
|
||||
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
|
||||
use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::fieldutils::i128_to_felt;
|
||||
use crate::tensor::{Tensor, ValTensor, ValType};
|
||||
use halo2_proofs::circuit::{Layouter, Value};
|
||||
use halo2_proofs::plonk::{ConstraintSystem, Error};
|
||||
@@ -320,7 +319,6 @@ impl GraphModules {
|
||||
instance_offset: &mut [usize],
|
||||
) -> Result<(), Error> {
|
||||
// reserve module 0 for ... modules
|
||||
|
||||
values.iter_mut().for_each(|x| {
|
||||
// hash the input and replace the constrained cells in the input
|
||||
let cloned_x = (*x).clone();
|
||||
@@ -414,7 +412,7 @@ impl GraphModules {
|
||||
|
||||
/// Run forward pass
|
||||
pub fn forward(
|
||||
inputs: &[Tensor<i128>],
|
||||
inputs: &[Tensor<Fp>],
|
||||
element_visibility: Visibility,
|
||||
) -> Result<ModuleForwardResult, Box<dyn std::error::Error>> {
|
||||
let mut rng = &mut rand::thread_rng();
|
||||
@@ -423,8 +421,7 @@ impl GraphModules {
|
||||
|
||||
if element_visibility.is_hashed() {
|
||||
let field_elements = inputs.iter().fold(vec![], |mut acc, x| {
|
||||
let field_elements = x.iter().map(|x| i128_to_felt::<Fp>(*x)).collect();
|
||||
let res = ModulePoseidon::run(field_elements).unwrap()[0].clone();
|
||||
let res = ModulePoseidon::run(x.to_vec()).unwrap()[0].clone();
|
||||
acc.extend(res);
|
||||
acc
|
||||
});
|
||||
@@ -435,9 +432,7 @@ impl GraphModules {
|
||||
let variables = ElGamalVariables::gen_random(&mut rng);
|
||||
|
||||
let elgamal_outputs = inputs.iter().fold(vec![], |mut acc: Vec<Vec<Fp>>, x| {
|
||||
let field_elements = x.iter().map(|x| i128_to_felt::<Fp>(*x)).collect();
|
||||
let ciphers = ElGamalGadget::run((field_elements, variables.clone())).unwrap();
|
||||
|
||||
let ciphers = ElGamalGadget::run((x.to_vec(), variables.clone())).unwrap();
|
||||
if acc.is_empty() {
|
||||
ciphers
|
||||
} else {
|
||||
|
||||
@@ -35,9 +35,9 @@ use tract_onnx::tract_hir::{
|
||||
/// * `dims` - the dimensionality of the resulting [Tensor].
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float(elem: &f32, shift: f32, scale: u32) -> Result<i128, TensorError> {
|
||||
let mult = scale_to_multiplier(scale) as f32;
|
||||
let max_value = ((i128::MAX as f32 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
pub fn quantize_float(elem: &f64, shift: f64, scale: u32) -> Result<i128, TensorError> {
|
||||
let mult = scale_to_multiplier(scale);
|
||||
let max_value = ((i128::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
|
||||
if *elem > max_value {
|
||||
return Err(TensorError::SigBitTruncationError);
|
||||
@@ -774,14 +774,16 @@ pub fn tensor_to_valtensor<F: PrimeField + TensorType + PartialOrd>(
|
||||
Visibility::Public => const_value
|
||||
.map(|x| {
|
||||
crate::tensor::ValType::Constant(crate::fieldutils::i128_to_felt::<F>(
|
||||
quantize_float(&x, 0.0, scale).unwrap(),
|
||||
quantize_float(&x.into(), 0.0, scale).unwrap(),
|
||||
))
|
||||
})
|
||||
.into(),
|
||||
Visibility::Private | Visibility::Hashed | Visibility::Encrypted => const_value
|
||||
.map(|x| {
|
||||
crate::tensor::ValType::Value(halo2_proofs::circuit::Value::known(
|
||||
crate::fieldutils::i128_to_felt::<F>(quantize_float(&x, 0.0, scale).unwrap()),
|
||||
crate::fieldutils::i128_to_felt::<F>(
|
||||
quantize_float(&x.into(), 0.0, scale).unwrap(),
|
||||
),
|
||||
))
|
||||
})
|
||||
.into(),
|
||||
|
||||
@@ -195,8 +195,39 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch the underlying [Tensor] of field elements.
|
||||
pub fn get_felt_evals(&self) -> Result<Tensor<F>, Box<dyn Error>> {
|
||||
let mut felt_evals: Vec<F> = vec![];
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: _, ..
|
||||
} => {
|
||||
// we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want)
|
||||
let _ = v.map(|vaf| match vaf {
|
||||
ValType::Value(v) => v.map(|f| {
|
||||
felt_evals.push(f);
|
||||
}),
|
||||
ValType::AssignedValue(v) => v.map(|f| {
|
||||
felt_evals.push(f.evaluate());
|
||||
}),
|
||||
ValType::PrevAssigned(v) => v.value_field().map(|f| {
|
||||
felt_evals.push(f.evaluate());
|
||||
}),
|
||||
ValType::Constant(v) => {
|
||||
felt_evals.push(v);
|
||||
Value::unknown()
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => return Err(Box::new(TensorError::WrongMethod)),
|
||||
};
|
||||
|
||||
let res: Tensor<F> = felt_evals.into_iter().into();
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Calls `int_evals` on the inner tensor.
|
||||
pub fn get_int_evals(&self) -> Result<Vec<i128>, Box<dyn Error>> {
|
||||
pub fn get_int_evals(&self) -> Result<Tensor<i128>, Box<dyn Error>> {
|
||||
// finally convert to vector of integers
|
||||
let mut integer_evals: Vec<i128> = vec![];
|
||||
match self {
|
||||
@@ -222,7 +253,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
_ => return Err(Box::new(TensorError::WrongMethod)),
|
||||
};
|
||||
Ok(integer_evals)
|
||||
Ok(integer_evals.into_iter().into())
|
||||
}
|
||||
|
||||
/// Calls `get_slice` on the inner tensor.
|
||||
|
||||
@@ -174,8 +174,9 @@ impl VarTensor {
|
||||
offset: usize,
|
||||
constant: F
|
||||
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error>{
|
||||
|
||||
let (x, y) = self.cartesian_coord(offset);
|
||||
|
||||
|
||||
match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
region.assign_advice_from_constant(|| "constant", advices[x], y, constant)
|
||||
@@ -184,10 +185,10 @@ impl VarTensor {
|
||||
region.assign_fixed(|| "constant", fixed[x], y, || Value::known(constant))
|
||||
}
|
||||
_ => panic!()
|
||||
|
||||
|
||||
}}
|
||||
|
||||
|
||||
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor.
|
||||
pub fn assign<F: PrimeField + TensorType + PartialOrd>(
|
||||
&self,
|
||||
|
||||
@@ -209,7 +209,7 @@ pub fn prove_wasm(
|
||||
.unwrap();
|
||||
|
||||
// prep public inputs
|
||||
circuit.load_data(&data).unwrap();
|
||||
circuit.load_graph_witness(&data).unwrap();
|
||||
let public_inputs = circuit.prepare_public_inputs(&data).unwrap();
|
||||
|
||||
let strategy = KZGSingleStrategy::new(¶ms);
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
mod native_tests {
|
||||
|
||||
use core::panic;
|
||||
use ezkl_lib::graph::input::GraphInput;
|
||||
use ezkl_lib::graph::DataSource;
|
||||
use ezkl_lib::graph::GraphWitness;
|
||||
use lazy_static::lazy_static;
|
||||
use std::env::var;
|
||||
use std::process::Command;
|
||||
@@ -126,7 +126,7 @@ mod native_tests {
|
||||
|
||||
assert!(status.success());
|
||||
|
||||
let data = GraphWitness::from_path(format!("{}/{}/input.json", test_dir, test).into())
|
||||
let data = GraphInput::from_path(format!("{}/{}/input.json", test_dir, test).into())
|
||||
.expect("failed to load input data");
|
||||
|
||||
let input_data = match data.input_data {
|
||||
@@ -134,25 +134,12 @@ mod native_tests {
|
||||
DataSource::OnChain(_) => panic!("Only File data sources support batching"),
|
||||
};
|
||||
|
||||
let output_data = match data.output_data {
|
||||
DataSource::File(data) => data,
|
||||
DataSource::OnChain(_) => panic!("Only File data sources support batching"),
|
||||
};
|
||||
|
||||
let duplicated_input_data: Vec<Vec<f32>> = input_data
|
||||
let duplicated_input_data: Vec<Vec<f64>> = input_data
|
||||
.iter()
|
||||
.map(|data| (0..num_batches).flat_map(|_| data.clone()).collect())
|
||||
.collect();
|
||||
|
||||
let duplicated_output_data: Vec<Vec<f32>> = output_data
|
||||
.iter()
|
||||
.map(|data| (0..num_batches).flat_map(|_| data.clone()).collect())
|
||||
.collect();
|
||||
|
||||
let duplicated_data = GraphWitness::new(
|
||||
DataSource::File(duplicated_input_data),
|
||||
DataSource::File(duplicated_output_data),
|
||||
);
|
||||
let duplicated_data = GraphInput::new(DataSource::File(duplicated_input_data));
|
||||
|
||||
let res =
|
||||
duplicated_data.save(format!("{}/{}/input.json", test_dir, output_dir).into());
|
||||
|
||||
@@ -153,10 +153,8 @@ def test_forward():
|
||||
with open(output_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
assert data["input_data"] == res["input_data"] == [[0.05326242372393608, 0.07497056573629379, 0.05235547572374344, 0.028825461864471436, 0.05848702788352966,
|
||||
0.008225822821259499, 0.07530029118061066, 0.0821458026766777, 0.06227986887097359, 0.024306034669280052, 0.05793173983693123, 0.040442030876874924]]
|
||||
assert data["output_data"] == res["output_data"] == [[0.05322265625, 0.12841796875, 0.0751953125, 0.10546875, 0.20947265625, 0.10400390625, 0.05224609375, 0.0810546875, 0.02880859375, 0.05859375, 0.06689453125,
|
||||
0.00830078125, 0.1337890625, 0.22412109375, 0.09033203125, 0.0751953125, 0.1572265625, 0.08203125, 0.0625, 0.0869140625, 0.0244140625, 0.12060546875, 0.185546875, 0.06494140625, 0.05810546875, 0.0986328125, 0.04052734375]]
|
||||
assert data["input_data"] == res["input_data"]
|
||||
assert data["output_data"] == res["output_data"]
|
||||
|
||||
assert data["processed_inputs"]["poseidon_hash"] == res["processed_inputs"]["poseidon_hash"] == [[
|
||||
8270957937025516140, 11801026918842104328, 2203849898884507041, 140307258138425306]]
|
||||
@@ -452,8 +450,16 @@ async def aggregate_and_verify_aggr():
|
||||
|
||||
proof_path = os.path.join(folder_path, '1l_relu.pf')
|
||||
|
||||
output_path = os.path.join(
|
||||
folder_path,
|
||||
'1l_relu_aggr_witness.json'
|
||||
)
|
||||
|
||||
res = ezkl_lib.gen_witness(data_path, model_path,
|
||||
output_path, settings_path=settings_path)
|
||||
|
||||
ezkl_lib.prove(
|
||||
data_path,
|
||||
output_path,
|
||||
model_path,
|
||||
pk_path,
|
||||
proof_path,
|
||||
@@ -540,8 +546,16 @@ async def evm_aggregate_and_verify_aggr():
|
||||
|
||||
proof_path = os.path.join(folder_path, '1l_relu.pf')
|
||||
|
||||
output_path = os.path.join(
|
||||
folder_path,
|
||||
'1l_relu_aggr_evm_witness.json'
|
||||
)
|
||||
|
||||
res = ezkl_lib.gen_witness(data_path, model_path,
|
||||
output_path, settings_path=settings_path)
|
||||
|
||||
ezkl_lib.prove(
|
||||
data_path,
|
||||
output_path,
|
||||
model_path,
|
||||
pk_path,
|
||||
proof_path,
|
||||
|
||||
@@ -1,407 +1,52 @@
|
||||
{
|
||||
"input_data": [
|
||||
[
|
||||
1.5417295,
|
||||
0.5346153,
|
||||
1.2172532
|
||||
[
|
||||
6425625360762666998,
|
||||
7924344314350639699,
|
||||
14762033076929465436,
|
||||
2023505479389396574
|
||||
],
|
||||
[
|
||||
12436184717236109307,
|
||||
3962172157175319849,
|
||||
7381016538464732718,
|
||||
1011752739694698287
|
||||
],
|
||||
[
|
||||
12436184717236109307,
|
||||
3962172157175319849,
|
||||
7381016538464732718,
|
||||
1011752739694698287
|
||||
]
|
||||
]
|
||||
],
|
||||
"output_data": [
|
||||
[
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0
|
||||
]
|
||||
]
|
||||
],
|
||||
"processed_inputs": {
|
||||
"poseidon_hash": [
|
||||
[
|
||||
1205357537771999130,
|
||||
5984390958877082714,
|
||||
9820729522268186501,
|
||||
2234099058051285368
|
||||
]
|
||||
],
|
||||
"elgmal_results": {
|
||||
"variables": {
|
||||
"r": [
|
||||
8241161763432706976,
|
||||
666255099567061377,
|
||||
14805312420564963680,
|
||||
1204508807060214874
|
||||
],
|
||||
"pk": {
|
||||
"x": [
|
||||
15190346206350482033,
|
||||
5345466492465898205,
|
||||
15673625536425086517,
|
||||
2946938976802915929
|
||||
],
|
||||
"y": [
|
||||
3849689718459676759,
|
||||
2554904641803052543,
|
||||
17914557513841193544,
|
||||
2568907273734445593
|
||||
]
|
||||
},
|
||||
"sk": [
|
||||
11176153975878706009,
|
||||
4400691958827003321,
|
||||
2763850695090718903,
|
||||
2596642706525330605
|
||||
],
|
||||
"window_size": 1,
|
||||
"aux_generator": {
|
||||
"x": [
|
||||
10023487266988000827,
|
||||
5757424535551704030,
|
||||
14065811819017675132,
|
||||
708744013093495848
|
||||
],
|
||||
"y": [
|
||||
3035630535707533851,
|
||||
15084156752173132954,
|
||||
4896068836772202981,
|
||||
3481542306236198361
|
||||
]
|
||||
}
|
||||
},
|
||||
"ciphertexts": [
|
||||
[],
|
||||
[
|
||||
[
|
||||
17425460071899514968,
|
||||
1167359589260142187,
|
||||
6756444933589811955,
|
||||
2891408151901966517
|
||||
],
|
||||
[
|
||||
829458805722111377,
|
||||
17250282503078918728,
|
||||
7016627888695675713,
|
||||
2925376423043353791
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
7811839396291636301,
|
||||
16626502080672679222,
|
||||
7596684602016322355,
|
||||
1171033058875794198
|
||||
],
|
||||
[
|
||||
13822398752765078610,
|
||||
12664329923497359372,
|
||||
215668063551589637,
|
||||
159280319181095911
|
||||
],
|
||||
[
|
||||
13822398752765078610,
|
||||
12664329923497359372,
|
||||
215668063551589637,
|
||||
159280319181095911
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
9283435499189890162,
|
||||
12255697202589512688,
|
||||
8119465263255675837,
|
||||
3239333981286649072
|
||||
]
|
||||
]
|
||||
]
|
||||
}
|
||||
},
|
||||
"processed_params": {
|
||||
"poseidon_hash": [
|
||||
[
|
||||
2438338217603384681,
|
||||
4147850557857422208,
|
||||
14984698365426657752,
|
||||
2960724354489209719
|
||||
]
|
||||
],
|
||||
"elgamal": {
|
||||
"variables": {
|
||||
"r": [
|
||||
1200181915139776053,
|
||||
13180175622696807913,
|
||||
14476695540645645352,
|
||||
1345741327868623972
|
||||
],
|
||||
"pk": {
|
||||
"x": [
|
||||
9797896387342524947,
|
||||
15035349603715646648,
|
||||
13792633755538567363,
|
||||
1133541516276092058
|
||||
],
|
||||
"y": [
|
||||
9476212339347896300,
|
||||
11232279511380632903,
|
||||
13562657712233073770,
|
||||
2203529712687503461
|
||||
]
|
||||
},
|
||||
"sk": [
|
||||
15504567943908784228,
|
||||
16286748745616108518,
|
||||
11666349446092175813,
|
||||
739552556110375646
|
||||
],
|
||||
"window_size": 1,
|
||||
"aux_generator": {
|
||||
"x": [
|
||||
3146347569958584183,
|
||||
5994772996379248674,
|
||||
12027192890968946425,
|
||||
250337585815942863
|
||||
],
|
||||
"y": [
|
||||
15759508736901521891,
|
||||
13730739491295063583,
|
||||
12178033622663772459,
|
||||
328932231313987632
|
||||
]
|
||||
}
|
||||
},
|
||||
"ciphertexts": [
|
||||
[],
|
||||
[
|
||||
[
|
||||
3687758631139381896,
|
||||
16990020959593879569,
|
||||
11442997537587059238,
|
||||
2601192457456256832
|
||||
],
|
||||
[
|
||||
9698149576377935801,
|
||||
11180115031102909471,
|
||||
10273504557556343240,
|
||||
1717383746578603920
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
3919368951425050505,
|
||||
13231007347102958053,
|
||||
15773012677856519663,
|
||||
3274073910421788953
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
3919368951425050505,
|
||||
13231007347102958053,
|
||||
15773012677856519663,
|
||||
3274073910421788953
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
3919368951425050505,
|
||||
13231007347102958053,
|
||||
15773012677856519663,
|
||||
3274073910421788953
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
5453533626151118718,
|
||||
18258437278146751399,
|
||||
17253853803511290350,
|
||||
1810581123008214862
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
],
|
||||
[
|
||||
11464092982624561027,
|
||||
14296265120971431549,
|
||||
9872837265046557632,
|
||||
798828383313516575
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
11835554726384377649,
|
||||
3476093169659441207,
|
||||
10317446023932452752,
|
||||
1341378657344074082
|
||||
]
|
||||
]
|
||||
]
|
||||
}
|
||||
},
|
||||
"processed_outputs": {
|
||||
"poseidon_hash": [
|
||||
[
|
||||
13525436761904490906,
|
||||
5108452989262774252,
|
||||
17113314644501751557,
|
||||
3448479653599714018
|
||||
]
|
||||
],
|
||||
"elgmal_results": {
|
||||
"variables": {
|
||||
"r": [
|
||||
4178821741174955177,
|
||||
14671768628465678513,
|
||||
4455016481844455948,
|
||||
596401571849164806
|
||||
],
|
||||
"pk": {
|
||||
"x": [
|
||||
17993908268456464107,
|
||||
12767282455815797736,
|
||||
6841806501674431080,
|
||||
147587434280887521
|
||||
],
|
||||
"y": [
|
||||
16992425197354392191,
|
||||
9664339530271872349,
|
||||
1076902239844679755,
|
||||
2685162918446917537
|
||||
]
|
||||
},
|
||||
"sk": [
|
||||
1307030728222113769,
|
||||
11176383346436204449,
|
||||
2128349950767135257,
|
||||
1684071711526912194
|
||||
],
|
||||
"window_size": 1,
|
||||
"aux_generator": {
|
||||
"x": [
|
||||
4657277071614079755,
|
||||
1369367041476422289,
|
||||
14912396191352368123,
|
||||
3115177155409298233
|
||||
],
|
||||
"y": [
|
||||
11259391534341870480,
|
||||
12985819015883975733,
|
||||
5695317503694964668,
|
||||
1609774512612003893
|
||||
]
|
||||
}
|
||||
},
|
||||
"ciphertexts": [
|
||||
[],
|
||||
[
|
||||
[
|
||||
310776321632342963,
|
||||
8782220497026581904,
|
||||
967332862939820787,
|
||||
2438830089048710677
|
||||
],
|
||||
[
|
||||
17188119484612178357,
|
||||
5607787842121139250,
|
||||
12990365590313846177,
|
||||
2940512950316690900
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
11468406689166593318,
|
||||
9609755014076369146,
|
||||
782845172795013879,
|
||||
2433980276123347322
|
||||
],
|
||||
[
|
||||
11468406689166593318,
|
||||
9609755014076369146,
|
||||
782845172795013879,
|
||||
2433980276123347322
|
||||
],
|
||||
[
|
||||
11468406689166593318,
|
||||
9609755014076369146,
|
||||
782845172795013879,
|
||||
2433980276123347322
|
||||
],
|
||||
[
|
||||
11468406689166593318,
|
||||
9609755014076369146,
|
||||
782845172795013879,
|
||||
2433980276123347322
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
15812472742830711787,
|
||||
11777278746141567566,
|
||||
4910275528091806968,
|
||||
1659402969803748057
|
||||
]
|
||||
]
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user