refactor: split input and witness data types (#334)

This commit is contained in:
dante
2023-06-30 03:55:27 +01:00
committed by GitHub
parent 2180e0f19c
commit 4e350ec06c
19 changed files with 459 additions and 661 deletions

View File

@@ -2,6 +2,7 @@ use std::any::Any;
use crate::{
circuit::{self, layouts, Tolerance},
fieldutils::{felt_to_i128, i128_to_felt},
graph::scale_to_multiplier,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
@@ -38,27 +39,32 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = inputs[0].clone().map(|x| felt_to_i128(x));
let (res, intermediate_lookups) = match &self {
HybridOp::Max { axes, .. } => (tensor::ops::max_axes(&inputs[0], axes)?, vec![]),
HybridOp::Max { axes, .. } => (tensor::ops::max_axes(&x, axes)?, vec![]),
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
..
} => (
tensor::ops::max_pool2d(&inputs[0], padding, stride, pool_dims)?,
tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
vec![],
),
HybridOp::Min { axes, .. } => (tensor::ops::min_axes(&inputs[0], axes)?, vec![]),
HybridOp::Min { axes, .. } => (tensor::ops::min_axes(&x, axes)?, vec![]),
HybridOp::Softmax { scales } => {
tensor::ops::nonlinearities::multi_dim_softmax(&inputs[0], scales.0, scales.1)
tensor::ops::nonlinearities::multi_dim_softmax(&x, scales.0, scales.1)
}
HybridOp::RangeCheck(..) => (inputs[0].clone(), vec![]),
HybridOp::RangeCheck(..) => (x.clone(), vec![]),
};
// convert back to felt
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult {
output: res,
output,
intermediate_lookups,
})
}

View File

@@ -1396,23 +1396,19 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
let w = region.assign(&config.lookup_input, x)?;
// extract integer_valuations
let integer_evals: Tensor<i128> = w
.get_int_evals()
.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?
.into_iter()
.into();
let felt_evals: Tensor<F> = w.get_felt_evals().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
// for key generation integer_evals will be empty and we need to return a set of unassigned values
let output: Tensor<Value<F>> = match integer_evals.len() {
let output: Tensor<Value<F>> = match felt_evals.len() {
// if empty return an unknown val
0 => Tensor::from((0..x.dims().iter().product::<usize>()).map(|_| Value::unknown())),
// if not empty apply the nonlinearity !
_ => {
let x = Op::<F>::f(nl, &[integer_evals])?;
x.output.map(|elem| Value::known(i128_to_felt(elem)))
let x = Op::<F>::f(nl, &[felt_evals])?;
x.output.map(|elem| Value::known(elem))
}
};

View File

@@ -4,7 +4,7 @@ use std::error::Error;
use crate::{
circuit::{layouts, utils},
fieldutils::i128_to_felt,
fieldutils::{felt_to_i128, i128_to_felt},
graph::scale_to_multiplier,
tensor::{self, Tensor, TensorError, TensorType},
};
@@ -32,10 +32,10 @@ pub enum LookupOp {
impl LookupOp {
/// a value which is always in the table
pub fn default_pair<F: PrimeField + TensorType + PartialOrd>(&self) -> (F, F) {
let x = vec![0_i128].into_iter().into();
let x = vec![i128_to_felt(0_i128)].into_iter().into();
(
<F as TensorType>::zero().unwrap(),
i128_to_felt(Op::<F>::f(self, &[x]).unwrap().output[0]),
Op::<F>::f(self, &[x]).unwrap().output[0],
)
}
}
@@ -45,50 +45,51 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i128(x));
let res = match &self {
LookupOp::GreaterThan { a } => Ok(tensor::ops::nonlinearities::greater_than(
&x[0],
&x,
f32::from(*a).into(),
)),
LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div(
&x[0],
&x,
f32::from(*denom).into(),
)),
LookupOp::Recip { scale } => {
Ok(tensor::ops::nonlinearities::recip(&x[0], *scale as u32))
}
LookupOp::Recip { scale } => Ok(tensor::ops::nonlinearities::recip(&x, *scale as u32)),
LookupOp::ReLU { scale } => {
Ok(tensor::ops::nonlinearities::leakyrelu(&x[0], *scale, 0_f64))
Ok(tensor::ops::nonlinearities::leakyrelu(&x, *scale, 0_f64))
}
LookupOp::LeakyReLU { scale, slope } => Ok(tensor::ops::nonlinearities::leakyrelu(
&x[0],
&x,
*scale,
slope.0.into(),
)),
LookupOp::Sigmoid { scales } => Ok(tensor::ops::nonlinearities::sigmoid(
&x[0], scales.0, scales.1,
)),
LookupOp::Sigmoid { scales } => {
Ok(tensor::ops::nonlinearities::sigmoid(&x, scales.0, scales.1))
}
LookupOp::Sqrt { scales } => {
Ok(tensor::ops::nonlinearities::sqrt(&x[0], scales.0, scales.1))
Ok(tensor::ops::nonlinearities::sqrt(&x, scales.0, scales.1))
}
LookupOp::Rsqrt { scales } => {
Ok(tensor::ops::nonlinearities::rsqrt(&x, scales.0, scales.1))
}
LookupOp::Rsqrt { scales } => Ok(tensor::ops::nonlinearities::rsqrt(
&x[0], scales.0, scales.1,
)),
LookupOp::Tanh { scales } => {
Ok(tensor::ops::nonlinearities::tanh(&x[0], scales.0, scales.1))
Ok(tensor::ops::nonlinearities::tanh(&x, scales.0, scales.1))
}
LookupOp::Erf { scales } => {
Ok(tensor::ops::nonlinearities::erffunc(&x, scales.0, scales.1))
}
LookupOp::Erf { scales } => Ok(tensor::ops::nonlinearities::erffunc(
&x[0], scales.0, scales.1,
)),
LookupOp::Exp { scales } => {
Ok(tensor::ops::nonlinearities::exp(&x[0], scales.0, scales.1))
Ok(tensor::ops::nonlinearities::exp(&x, scales.0, scales.1))
}
}?;
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult {
output: res,
output,
intermediate_lookups: vec![],
})
}

View File

@@ -3,7 +3,10 @@ use std::{any::Any, error::Error};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use crate::tensor::{self, Tensor, TensorError, TensorType, ValTensor};
use crate::{
fieldutils::{felt_to_i128, i128_to_felt},
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use self::{lookup::LookupOp, region::RegionCtx};
@@ -25,15 +28,15 @@ pub mod region;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult {
pub(crate) output: Tensor<i128>,
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub(crate) output: Tensor<F>,
pub(crate) intermediate_lookups: Vec<Tensor<i128>>,
}
/// An enum representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<i128>]) -> Result<ForwardResult, TensorError>;
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
/// Returns a string representation of the operation.
fn as_string(&self) -> String;
@@ -97,7 +100,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
self
}
fn f(&self, x: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Ok(ForwardResult {
output: x[0].clone(),
intermediate_lookups: vec![],
@@ -143,7 +146,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Rescaled<F> {
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, x: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
if self.scale.len() != x.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
}
@@ -151,10 +154,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Rescaled<F> {
let mut rescaled_inputs = vec![];
let inputs = &mut x.to_vec();
for (i, ri) in inputs.iter_mut().enumerate() {
rescaled_inputs.push(tensor::ops::nonlinearities::const_div(
ri,
self.scale[i].1 as f64,
));
let ri = ri.map(|x| felt_to_i128(x));
let res = tensor::ops::nonlinearities::const_div(&ri, self.scale[i].1 as f64);
let output = res.map(|x| i128_to_felt(x));
rescaled_inputs.push(output);
}
Op::<F>::f(&*self.inner, &rescaled_inputs)
}
@@ -218,7 +221,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Err(TensorError::WrongMethod)
}
@@ -265,10 +268,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Constant<F> {
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
let values = self.quantized_values.clone();
let int_values = values.get_int_evals().unwrap();
let output = Tensor::new(Some(&int_values), values.dims())?;
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut output = self.quantized_values.get_felt_evals().unwrap();
// make sure its the right shape
output.reshape(self.quantized_values.dims());
Ok(ForwardResult {
output,

View File

@@ -98,7 +98,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<i128>]) -> Result<ForwardResult, TensorError> {
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut inputs = inputs.to_vec();
let res = match &self {
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
@@ -124,14 +124,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
}
PolyOp::Add { a } => {
if let Some(a) = a {
inputs.push(Tensor::new(Some(&a.get_int_evals().unwrap()), a.dims())?);
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
}
tensor::ops::add(&inputs)
}
PolyOp::Sub => tensor::ops::sub(&inputs),
PolyOp::Mult { a } => {
if let Some(a) = a {
inputs.push(Tensor::new(Some(&a.get_int_evals().unwrap()), a.dims())?);
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
}
tensor::ops::mult(&inputs)
}
@@ -141,9 +141,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
padding,
stride,
} => {
inputs.push(Tensor::new(Some(&a.get_int_evals().unwrap()), a.dims())?);
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
if let Some(b) = bias {
inputs.push(Tensor::new(Some(&b.get_int_evals().unwrap()), b.dims())?);
inputs.push(Tensor::new(Some(&b.get_felt_evals().unwrap()), b.dims())?);
}
tensor::ops::conv(&inputs, *padding, *stride)
}
@@ -154,9 +154,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
output_padding,
stride,
} => {
inputs.push(Tensor::new(Some(&a.get_int_evals().unwrap()), a.dims())?);
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
if let Some(b) = bias {
inputs.push(Tensor::new(Some(&b.get_int_evals().unwrap()), b.dims())?);
inputs.push(Tensor::new(Some(&b.get_felt_evals().unwrap()), b.dims())?);
}
tensor::ops::deconv(&inputs, *padding, *output_padding, *stride)
}
@@ -170,7 +170,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
return Err(TensorError::DimMismatch("pack inputs".to_string()));
}
tensor::ops::pack(&inputs[0], *base as i128, *scale)
tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale)
}
PolyOp::Pow(u) => {
if 1 != inputs.len() {
@@ -218,9 +218,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
layouts::resize(config, region, values[..].try_into()?, scale_factor)?
}
PolyOp::Iff => layouts::iff(config, region, values[..].try_into()?)?,
PolyOp::Einsum { equation } => {
layouts::einsum(config, region, &mut values, equation)?
}
PolyOp::Einsum { equation } => layouts::einsum(config, region, &mut values, equation)?,
PolyOp::Gather { dim, index } => {
tensor::ops::gather(&values[0].get_inner_tensor()?, *dim, index)?.into()
}

View File

@@ -59,7 +59,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let smallest = -base.pow(self.bits as u32 - 1);
let largest = base.pow(self.bits as u32 - 1);
let inputs = Tensor::from(smallest..=largest);
let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
self.is_assigned = true;
@@ -75,14 +75,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|| format!("nl_i_col row {}", row_offset),
self.table_input,
row_offset,
|| Value::known(i128_to_felt::<F>(*input)),
|| Value::known(*input),
)?;
table.assign_cell(
|| format!("nl_o_col row {}", row_offset),
self.table_output,
row_offset,
|| Value::known(i128_to_felt::<F>(evals.output[row_offset])),
|| Value::known(evals.output[row_offset]),
)?;
Ok(())
})

View File

@@ -1,4 +1,4 @@
use crate::graph::input::{CallsToAccount, DataSource, GraphWitness};
use crate::graph::input::{CallsToAccount, GraphWitness, WitnessSource};
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::GraphSettings;
use crate::pfsys::evm::{DeploymentCode, EvmVerificationError};
@@ -164,12 +164,12 @@ pub async fn deploy_da_verifier_via_solidity(
let mut instance_idx = 0;
let mut contract_instance_offset = 0;
if let DataSource::OnChain(source) = witness.input_data {
if let WitnessSource::OnChain(source) = witness.input_data {
for call in source.calls {
calls_to_accounts.push(call);
instance_idx += 1;
}
} else if let DataSource::File(source) = witness.input_data {
} else if let WitnessSource::File(source) = witness.input_data {
if settings.run_args.input_visibility.is_public() {
instance_idx += source.len();
for s in source {
@@ -178,7 +178,7 @@ pub async fn deploy_da_verifier_via_solidity(
}
}
if let DataSource::OnChain(source) = witness.output_data {
if let WitnessSource::OnChain(source) = witness.output_data {
let output_scales = settings.model_output_scales;
for call in source.calls {
calls_to_accounts.push(call);
@@ -393,7 +393,7 @@ pub fn get_provider(rpc_url: &str) -> Result<Provider<Http>, Box<dyn Error>> {
/// the number of decimals of the floating point value on chain.
pub async fn test_on_chain_data<M: 'static + Middleware>(
client: Arc<M>,
data: &Vec<Vec<f32>>,
data: &[Vec<f32>],
) -> Result<Vec<CallsToAccount>, Box<dyn Error>> {
let (contract, decimals) = setup_test_contract(client.clone(), data).await?;
@@ -454,7 +454,7 @@ pub async fn evm_quantize<M: 'static + Middleware>(
client: Arc<M>,
scales: Vec<f64>,
data: &(Vec<ethers::types::Bytes>, Vec<u8>),
) -> Result<Vec<i128>, Box<dyn Error>> {
) -> Result<Vec<Fr>, Box<dyn Error>> {
// save the sol to a tmp file
let mut sol_path = std::env::temp_dir();
sol_path.push("quantizedata.sol");
@@ -495,7 +495,11 @@ pub async fn evm_quantize<M: 'static + Middleware>(
.call()
.await;
let results = results.unwrap();
let results = results
.unwrap()
.iter()
.map(|x| crate::fieldutils::i128_to_felt(*x))
.collect::<Vec<Fr>>();
info!("evm quantization results: {:#?}", results,);
Ok(results.to_vec())
}
@@ -513,7 +517,7 @@ fn get_sol_contract_factory<M: 'static + Middleware>(
if size > MAX_RUNTIME_BYTECODE_SIZE {
// `_runtime_bytecode` exceeds the limit
panic!(
"Solidity runtime bytecode size is: {:#?},
"Solidity runtime bytecode size is: {:#?},
which exceeds 24577 bytes limit.
Try setting '--optimzer-runs 1' when generating the verifier
so SOLC can optimize for the smallest deployment",

View File

@@ -6,8 +6,8 @@ use crate::commands::{Cli, Commands, RunArgs};
use crate::eth::{deploy_da_verifier_via_solidity, deploy_verifier_via_solidity};
#[cfg(not(target_arch = "wasm32"))]
use crate::eth::{fix_verifier_sol, get_contract_artifacts, verify_proof_via_solidity};
use crate::graph::input::{DataSource, GraphInput};
use crate::graph::{scale_to_multiplier, GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::input::{GraphInput, WitnessSource};
use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::{TestDataSource, TestSources, Visibility};
use crate::pfsys::evm::aggregation::{AggregationCircuit, PoseidonTranscript};
@@ -497,9 +497,9 @@ pub(crate) async fn gen_witness(
let data = GraphInput::from_path(data)?;
#[cfg(not(target_arch = "wasm32"))]
circuit.load_inputs(&data).await?;
circuit.load_graph_input(&data).await?;
#[cfg(target_arch = "wasm32")]
circuit.load_inputs(&data)?;
circuit.load_graph_input(&data)?;
let start_time = Instant::now();
@@ -516,22 +516,22 @@ pub(crate) async fn gen_witness(
res.outputs.iter().map(|t| t.dims()).collect_vec()
);
let output_scales = circuit.model.graph.get_output_scales();
let output_scales = output_scales
let input_witness: Vec<Vec<Fr>> = res
.inputs
.iter()
.map(|scale| scale_to_multiplier(*scale));
.map(|t| t.clone().into_iter().collect_vec())
.collect();
let float_res: Vec<Vec<f32>> = res
let output_witness: Vec<Vec<Fr>> = res
.outputs
.iter()
.zip(output_scales)
.map(|(t, scale)| t.iter().map(|e| ((*e as f64 / scale) as f32)).collect_vec())
.map(|t| t.clone().into_iter().collect_vec())
.collect();
trace!("model forward pass output: {:?}", float_res);
trace!("model forward pass output: {:?}", output_witness);
let witness = GraphWitness {
input_data: data.input_data,
output_data: DataSource::File(float_res),
input_data: WitnessSource::File(input_witness),
output_data: WitnessSource::File(output_witness),
processed_inputs: res.processed_inputs,
processed_params: res.processed_params,
processed_outputs: res.processed_outputs,
@@ -587,7 +587,7 @@ pub(crate) fn init_bar(len: u64) -> ProgressBar {
)
.unwrap()
.progress_chars("##-");
pb.set_style(sty.clone());
pb.set_style(sty);
pb
}
@@ -653,7 +653,7 @@ pub(crate) async fn calibrate(
tokio::task::spawn(async move {
circuit
.load_inputs(&chunk)
.load_graph_input(&chunk)
.await
.map_err(|_| "failed to load circuit inputs")
.unwrap();
@@ -762,9 +762,9 @@ pub(crate) async fn mock(
let data = GraphWitness::from_path(data_path.clone())?;
#[cfg(not(target_arch = "wasm32"))]
circuit.load_data(&data, None).await?;
circuit.load_graph_witness(&data, None).await?;
#[cfg(target_arch = "wasm32")]
circuit.load_data(&data)?;
circuit.load_graph_witness(&data)?;
let public_inputs = circuit.prepare_public_inputs(&data)?;
@@ -901,7 +901,7 @@ pub(crate) fn create_evm_data_attestation_verifier(
let data = GraphWitness::from_path(witness)?;
let output_data = if let DataSource::OnChain(source) = data.output_data {
let output_data = if let WitnessSource::OnChain(source) = data.output_data {
if !visibility.output.is_public() {
todo!("we currently don't support private output data on chain")
}
@@ -914,7 +914,7 @@ pub(crate) fn create_evm_data_attestation_verifier(
None
};
let input_data = if let DataSource::OnChain(source) = data.input_data {
let input_data = if let WitnessSource::OnChain(source) = data.input_data {
if !visibility.input.is_public() {
todo!("we currently don't support private input data on chain")
}
@@ -1096,7 +1096,7 @@ pub(crate) async fn setup_test_evm_witness(
};
circuit
.load_data(&data, Some(test_on_chain_witness))
.load_graph_witness(&data, Some(test_on_chain_witness))
.await?;
Ok(())
@@ -1119,7 +1119,7 @@ pub(crate) async fn prove(
let circuit_settings = GraphSettings::load(&settings_path)?;
let mut circuit = GraphCircuit::from_settings(&circuit_settings, &model_path, check_mode)?;
circuit.load_data(&data, None).await?;
circuit.load_graph_witness(&data, None).await?;
let public_inputs = circuit.prepare_public_inputs(&data)?;
let circuit_settings = circuit.settings.clone();
@@ -1202,7 +1202,7 @@ pub(crate) async fn fuzz(
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &params)
.map_err(Box::<dyn Error>::from)?;
circuit.load_data(&data, None).await?;
circuit.load_graph_witness(&data, None).await?;
let public_inputs = circuit.prepare_public_inputs(&data)?;
let strategy = KZGSingleStrategy::new(&params);

View File

@@ -1,3 +1,4 @@
use halo2curves::bn256::Fr as Fp;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
@@ -20,8 +21,10 @@ type Decimals = u8;
type Call = String;
type RPCUrl = String;
/// Inner elements of inputs/outputs coming from a file
pub type FileSourceInner = Vec<Vec<f32>>;
/// Inner elements of inputs coming from a file
pub type FileSourceInner = Vec<Vec<f64>>;
/// Inner elements of witness coming from a witness
pub type WitnessFileSourceInner = Vec<Vec<Fp>>;
/// Inner elements of inputs/outputs coming from on-chain
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct OnChainSourceInner {
@@ -42,13 +45,14 @@ impl OnChainSourceInner {
#[cfg(not(target_arch = "wasm32"))]
/// Create dummy local on-chain data to test the OnChain data source
pub async fn test_from_file_data(
data: &FileSourceInner,
data: &WitnessFileSourceInner,
scales: Vec<u32>,
shapes: Vec<Vec<usize>>,
rpc: Option<&str>,
) -> Result<(Vec<Tensor<i128>>, Self), Box<dyn std::error::Error>> {
) -> Result<(Vec<Tensor<Fp>>, Self), Box<dyn std::error::Error>> {
use crate::eth::{evm_quantize, read_on_chain_inputs, test_on_chain_data};
use crate::graph::scale_to_multiplier;
use itertools::Itertools;
use log::debug;
// Set up local anvil instance for reading on-chain data
@@ -56,13 +60,25 @@ impl OnChainSourceInner {
let address = client.address();
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
let scales: Vec<f64> = scales.into_iter().map(scale_to_multiplier).collect();
// unquantize data
let float_data = data
.iter()
.zip(scales.iter())
.map(|(t, scale)| {
t.iter()
.map(|e| ((crate::fieldutils::felt_to_i128(*e) as f64 / scale) as f32))
.collect_vec()
})
.collect::<Vec<Vec<f32>>>();
let calls_to_accounts = test_on_chain_data(client.clone(), &float_data).await?;
debug!("Calls to accounts: {:?}", calls_to_accounts);
let inputs = read_on_chain_inputs(client.clone(), address, &calls_to_accounts).await?;
debug!("Inputs: {:?}", inputs);
let mut quantized_evm_inputs = vec![];
let scales: Vec<f64> = scales.into_iter().map(scale_to_multiplier).collect();
let mut prev = 0;
for (idx, i) in data.iter().enumerate() {
@@ -81,9 +97,9 @@ impl OnChainSourceInner {
}
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
let mut inputs: Vec<Tensor<i128>> = vec![];
let mut inputs: Vec<Tensor<Fp>> = vec![];
for (input, shape) in vec![quantized_evm_inputs].iter().zip(shapes) {
let mut t: Tensor<i128> = input.iter().cloned().collect();
let mut t: Tensor<Fp> = input.iter().cloned().collect();
t.reshape(&shape);
inputs.push(t);
}
@@ -141,26 +157,29 @@ impl From<OnChainSourceInner> for DataSource {
}
/// Enum that defines source of the inputs/outputs to the EZKL model
/// used for f32 to f64 conversion
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
#[serde(untagged)]
enum DataSourceF64 {
pub enum WitnessSource {
/// .json File data source.
File(WitnessFileSourceInner),
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
OnChain(OnChainSourceInner),
File(Vec<Vec<f64>>),
}
impl Default for WitnessSource {
fn default() -> Self {
WitnessSource::File(vec![vec![]])
}
}
impl From<DataSource> for DataSourceF64 {
fn from(source: DataSource) -> Self {
match source {
DataSource::File(data) => {
let data = data
.iter()
.map(|v| v.iter().map(|&f| f as f64).collect::<Vec<_>>())
.collect::<Vec<_>>();
DataSourceF64::File(data)
}
DataSource::OnChain(source) => DataSourceF64::OnChain(source),
}
impl From<WitnessFileSourceInner> for WitnessSource {
fn from(data: WitnessFileSourceInner) -> Self {
WitnessSource::File(data)
}
}
impl From<OnChainSourceInner> for WitnessSource {
fn from(data: OnChainSourceInner) -> Self {
WitnessSource::OnChain(data)
}
}
@@ -170,9 +189,9 @@ impl From<DataSource> for DataSourceF64 {
pub struct GraphWitness {
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
/// TODO: Add retrieve from on-chain functionality
pub input_data: DataSource,
pub input_data: WitnessSource,
/// The expected output of the model (can be empty vectors if outputs are not being constrained).
pub output_data: DataSource,
pub output_data: WitnessSource,
/// Optional hashes of the inputs (can be None if there are no commitments). Wrapped as Option for backwards compatibility
pub processed_inputs: Option<ModuleForwardResult>,
/// Optional hashes of the params (can be None if there are no commitments). Wrapped as Option for backwards compatibility
@@ -183,7 +202,7 @@ pub struct GraphWitness {
impl GraphWitness {
///
pub fn new(input_data: DataSource, output_data: DataSource) -> Self {
pub fn new(input_data: WitnessSource, output_data: WitnessSource) -> Self {
GraphWitness {
input_data,
output_data,
@@ -213,14 +232,6 @@ pub struct GraphInput {
pub input_data: DataSource,
}
impl From<GraphWitness> for GraphInput {
fn from(witness: GraphWitness) -> Self {
GraphInput {
input_data: witness.input_data,
}
}
}
impl GraphInput {
///
pub fn new(input_data: DataSource) -> Self {
@@ -302,11 +313,7 @@ impl GraphInput {
}
#[cfg(feature = "python-bindings")]
use halo2curves::{
bn256::{Fr as Fp, G1Affine},
ff::PrimeField,
serde::SerdeObject,
};
use halo2curves::{bn256::G1Affine, ff::PrimeField, serde::SerdeObject};
#[cfg(feature = "python-bindings")]
/// converts fp into Vec<u64>
@@ -400,6 +407,27 @@ impl ToPyObject for DataSource {
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for WitnessSource {
    /// Converts the witness source into a Python object:
    /// - `File` becomes a nested list where each field element is rendered as a
    ///   `Vec<u64>` via `field_to_vecu64`;
    /// - `OnChain` becomes a dict with `rpc_url` and `calls_to_accounts` keys.
    fn to_object(&self, py: Python) -> PyObject {
        match self {
            WitnessSource::File(rows) => {
                // Each field element is expanded into its u64-limb representation.
                let limbs: Vec<Vec<Vec<u64>>> = rows
                    .iter()
                    .map(|row| row.iter().map(field_to_vecu64).collect())
                    .collect();
                limbs.to_object(py)
            }
            WitnessSource::OnChain(chain) => {
                let dict = PyDict::new(py);
                dict.set_item("rpc_url", &chain.rpc).unwrap();
                dict.set_item("calls_to_accounts", &chain.calls).unwrap();
                dict.to_object(py)
            }
        }
    }
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for GraphWitness {
fn to_object(&self, py: Python) -> PyObject {
@@ -459,8 +487,7 @@ impl Serialize for GraphInput {
S: Serializer,
{
let mut state = serializer.serialize_struct("GraphInput", 4)?;
let input_data: DataSourceF64 = self.input_data.clone().into();
state.serialize_field("input_data", &input_data)?;
state.serialize_field("input_data", &self.input_data)?;
state.end()
}
}
@@ -488,16 +515,37 @@ impl<'de> Deserialize<'de> for DataSource {
}
}
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
impl<'de> Deserialize<'de> for WitnessSource {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
let first_try: Result<WitnessFileSourceInner, _> = serde_json::from_str(this_json.get());
if let Ok(t) = first_try {
return Ok(WitnessSource::File(t));
}
let second_try: Result<OnChainSourceInner, _> = serde_json::from_str(this_json.get());
if let Ok(t) = second_try {
return Ok(WitnessSource::OnChain(t));
}
Err(serde::de::Error::custom("failed to deserialize DataSource"))
}
}
impl Serialize for GraphWitness {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = serializer.serialize_struct("GraphWitness", 4)?;
let input_data: DataSourceF64 = self.input_data.clone().into();
let output_data: DataSourceF64 = self.output_data.clone().into();
state.serialize_field("input_data", &input_data)?;
state.serialize_field("output_data", &output_data)?;
state.serialize_field("input_data", &self.input_data)?;
state.serialize_field("output_data", &self.output_data)?;
if let Some(processed_inputs) = &self.processed_inputs {
state.serialize_field("processed_inputs", &processed_inputs)?;
@@ -540,9 +588,9 @@ mod tests {
// this is for backwards compatibility with the old format
fn test_graph_input_serialization_round_trip() {
let file = GraphInput::new(DataSource::File(vec![vec![
0.053_262_424,
0.074_970_566,
0.052_355_476,
0.05326242372393608,
0.07497056573629379,
0.05235547572374344,
]]));
let serialized = serde_json::to_string(&file).unwrap();
@@ -555,7 +603,6 @@ mod tests {
let graph_input3 = serde_json::from_str::<GraphInput>(JSON)
.map_err(|e| e.to_string())
.unwrap();
println!("{:?}", graph_input3.input_data);
assert_eq!(graph_input3, file);
}
}

View File

@@ -11,11 +11,12 @@ pub mod utilities;
/// Representations of a computational graph's variables.
pub mod vars;
pub use input::{DataSource, GraphWitness};
use halo2_proofs::circuit::Value;
pub use input::{DataSource, GraphWitness, WitnessSource};
#[cfg(not(target_arch = "wasm32"))]
use self::input::OnChainSourceInner;
use self::input::{FileSourceInner, GraphInput};
use self::input::{FileSourceInner, GraphInput, WitnessFileSourceInner};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::CheckMode;
@@ -24,7 +25,7 @@ use crate::fieldutils::i128_to_felt;
use crate::graph::modules::ModuleInstanceOffset;
use crate::tensor::{Tensor, ValTensor};
use halo2_proofs::{
circuit::{Layouter, Value},
circuit::Layouter,
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Fr as Fp};
@@ -99,9 +100,9 @@ const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ForwardResult {
/// The inputs of the forward pass
pub inputs: Vec<Tensor<i128>>,
pub inputs: Vec<Tensor<Fp>>,
/// The output of the forward pass
pub outputs: Vec<Tensor<i128>>,
pub outputs: Vec<Tensor<Fp>>,
/// Any hashes of inputs generated during the forward pass
pub processed_inputs: Option<ModuleForwardResult>,
/// Any hashes of params generated during the forward pass
@@ -173,9 +174,9 @@ pub struct GraphCircuit {
/// The model / graph of computations.
pub model: Model,
/// Vector of input tensors to the model / graph of computations.
pub inputs: Vec<Tensor<i128>>,
pub inputs: Vec<Tensor<Fp>>,
/// Vector of input tensors to the model / graph of computations.
pub outputs: Vec<Tensor<i128>>,
pub outputs: Vec<Tensor<Fp>>,
/// The settings of the model / graph of computations.
pub settings: GraphSettings,
/// The settings of the model's modules.
@@ -230,9 +231,9 @@ impl GraphCircuit {
check_mode: CheckMode,
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
// placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
let mut inputs: Vec<Tensor<i128>> = vec![];
let mut inputs: Vec<Tensor<Fp>> = vec![];
for shape in model.graph.input_shapes() {
let t: Tensor<i128> = Tensor::new(None, &shape).unwrap();
let t: Tensor<Fp> = Tensor::new(None, &shape).unwrap();
inputs.push(t);
}
@@ -277,9 +278,9 @@ impl GraphCircuit {
check_mode: CheckMode,
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
// placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
let mut inputs: Vec<Tensor<i128>> = vec![];
let mut inputs: Vec<Tensor<Fp>> = vec![];
for shape in model.graph.input_shapes() {
let t: Tensor<i128> = Tensor::new(None, &shape).unwrap();
let t: Tensor<Fp> = Tensor::new(None, &shape).unwrap();
inputs.push(t);
}
@@ -299,9 +300,14 @@ impl GraphCircuit {
#[cfg(target_arch = "wasm32")]
/// load inputs and outputs for the model
pub fn load_data(&mut self, data: &GraphWitness) -> Result<(), Box<dyn std::error::Error>> {
self.load_inputs(&data.clone().into())?;
self.load_outputs(data)?;
pub fn load_graph_witness(
&mut self,
data: &GraphWitness,
) -> Result<(), Box<dyn std::error::Error>> {
self.inputs =
self.process_witness_source(&data.input_data, self.model.graph.input_shapes())?;
self.outputs =
self.process_witness_source(&data.output_data, self.model.graph.output_shapes())?;
// load the module settings
self.module_settings = ModuleSettings::from(data);
@@ -310,7 +316,7 @@ impl GraphCircuit {
#[cfg(not(target_arch = "wasm32"))]
/// load inputs and outputs for the model
pub async fn load_data(
pub async fn load_graph_witness(
&mut self,
data: &GraphWitness,
test_on_chain_data: Option<TestOnChainData>,
@@ -322,8 +328,20 @@ impl GraphCircuit {
self.populate_on_chain_test_data(&mut data, test_path)
.await?;
} else {
self.load_inputs(&data.clone().into()).await?;
self.load_outputs(&data).await?;
self.inputs = self
.process_witness_source(
&data.input_data,
self.model.graph.input_shapes(),
self.model.graph.get_input_scales(),
)
.await?;
self.outputs = self
.process_witness_source(
&data.output_data,
self.model.graph.output_shapes(),
self.model.graph.get_output_scales(),
)
.await?;
}
// load the module settings
@@ -358,11 +376,7 @@ impl GraphCircuit {
let mut pi_inner: Vec<Vec<Fp>> = public_inputs
.iter()
.map(|i| {
i.iter()
.map(|e| i128_to_felt::<Fp>(*e))
.collect::<Vec<Fp>>()
})
.map(|i| i.clone().into_iter().collect::<Vec<Fp>>())
.collect::<Vec<Vec<Fp>>>();
let module_instances =
@@ -377,7 +391,10 @@ impl GraphCircuit {
///
#[cfg(target_arch = "wasm32")]
pub fn load_inputs(&mut self, data: &GraphInput) -> Result<(), Box<dyn std::error::Error>> {
pub fn load_graph_input(
&mut self,
data: &GraphInput,
) -> Result<(), Box<dyn std::error::Error>> {
let shapes = self.model.graph.input_shapes();
let scales = vec![self.settings.run_args.scale; shapes.len()];
self.inputs = self.process_data_source(&data.input_data, shapes, scales)?;
@@ -386,7 +403,7 @@ impl GraphCircuit {
///
#[cfg(not(target_arch = "wasm32"))]
pub async fn load_inputs(
pub async fn load_graph_input(
&mut self,
data: &GraphInput,
) -> Result<(), Box<dyn std::error::Error>> {
@@ -395,29 +412,7 @@ impl GraphCircuit {
self.inputs = self
.process_data_source(&data.input_data, shapes, scales)
.await?;
Ok(())
}
///
#[cfg(target_arch = "wasm32")]
pub fn load_outputs(&mut self, data: &GraphWitness) -> Result<(), Box<dyn std::error::Error>> {
let out_scales = self.model.graph.get_output_scales();
let shapes = self.model.graph.output_shapes();
self.outputs = self.process_data_source(&data.output_data, shapes, out_scales)?;
Ok(())
}
#[cfg(not(target_arch = "wasm32"))]
///
pub async fn load_outputs(
&mut self,
data: &GraphWitness,
) -> Result<(), Box<dyn std::error::Error>> {
let out_scales = self.model.graph.get_output_scales();
let shapes = self.model.graph.output_shapes();
self.outputs = self
.process_data_source(&data.output_data, shapes, out_scales)
.await?;
Ok(())
}
@@ -428,7 +423,7 @@ impl GraphCircuit {
data: &DataSource,
shapes: Vec<Vec<usize>>,
scales: Vec<u32>,
) -> Result<Vec<Tensor<i128>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
match &data {
DataSource::OnChain(_) => {
panic!("Cannot use on-chain data source as input for wasm rn.")
@@ -437,6 +432,21 @@ impl GraphCircuit {
}
}
#[cfg(target_arch = "wasm32")]
/// Process the data source for the model
fn process_witness_source(
&mut self,
data: &WitnessSource,
shapes: Vec<Vec<usize>>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
match &data {
WitnessSource::OnChain(_) => {
panic!("Cannot use on-chain data source as input for wasm rn.")
}
WitnessSource::File(file_data) => self.load_witness_file_data(file_data, &shapes),
}
}
#[cfg(not(target_arch = "wasm32"))]
/// Process the data source for the model
async fn process_data_source(
@@ -444,7 +454,7 @@ impl GraphCircuit {
data: &DataSource,
shapes: Vec<Vec<usize>>,
scales: Vec<u32>,
) -> Result<Vec<Tensor<i128>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
match &data {
DataSource::OnChain(source) => {
let mut per_item_scale = vec![];
@@ -458,6 +468,27 @@ impl GraphCircuit {
}
}
#[cfg(not(target_arch = "wasm32"))]
/// Process the data source for the model
async fn process_witness_source(
&mut self,
data: &WitnessSource,
shapes: Vec<Vec<usize>>,
scales: Vec<u32>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
match &data {
WitnessSource::OnChain(source) => {
let mut per_item_scale = vec![];
for (i, shape) in shapes.iter().enumerate() {
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
}
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
}
WitnessSource::File(file_data) => self.load_witness_file_data(file_data, &shapes),
}
}
/// Prepare on chain test data
#[cfg(not(target_arch = "wasm32"))]
pub async fn load_on_chain_data(
@@ -465,7 +496,7 @@ impl GraphCircuit {
source: OnChainSourceInner,
shapes: &Vec<Vec<usize>>,
scales: Vec<u32>,
) -> Result<Vec<Tensor<i128>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
let (_, client) = setup_eth_backend(Some(&source.rpc)).await?;
let inputs = read_on_chain_inputs(client.clone(), client.address(), &source.calls).await?;
@@ -477,9 +508,9 @@ impl GraphCircuit {
)
.await?;
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
let mut inputs: Vec<Tensor<i128>> = vec![];
let mut inputs: Vec<Tensor<Fp>> = vec![];
for (input, shape) in vec![quantized_evm_inputs].iter().zip(shapes) {
let mut t: Tensor<i128> = input.iter().cloned().collect();
let mut t: Tensor<Fp> = input.iter().cloned().collect();
t.reshape(shape);
inputs.push(t);
}
@@ -493,16 +524,16 @@ impl GraphCircuit {
file_data: &FileSourceInner,
shapes: &Vec<Vec<usize>>,
scales: Vec<u32>,
) -> Result<Vec<Tensor<i128>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
// quantize the supplied data using the provided scale.
let mut data: Vec<Tensor<i128>> = vec![];
let mut data: Vec<Tensor<Fp>> = vec![];
for ((d, shape), scale) in file_data.iter().zip(shapes).zip(scales) {
let t: Vec<i128> = d
let t: Vec<Fp> = d
.par_iter()
.map(|x| quantize_float(x, 0.0, scale).unwrap())
.map(|x| i128_to_felt(quantize_float(x, 0.0, scale).unwrap()))
.collect();
let mut t: Tensor<i128> = t.into_iter().into();
let mut t: Tensor<Fp> = t.into_iter().into();
t.reshape(shape);
data.push(t);
@@ -510,6 +541,22 @@ impl GraphCircuit {
Ok(data)
}
///
pub fn load_witness_file_data(
&mut self,
file_data: &WitnessFileSourceInner,
shapes: &Vec<Vec<usize>>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
// quantize the supplied data using the provided scale.
let mut data: Vec<Tensor<Fp>> = vec![];
for (d, shape) in file_data.iter().zip(shapes) {
let mut t: Tensor<Fp> = d.clone().into_iter().into();
t.reshape(shape);
data.push(t);
}
Ok(data)
}
/// Calibrate the circuit to the supplied data.
pub fn calibrate(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let res = self.forward()?;
@@ -575,7 +622,7 @@ impl GraphCircuit {
if visibility.params.requires_processing() {
let params = self.model.get_all_consts();
let flattened_params = flatten_valtensors(params)?
.get_int_evals()?
.get_felt_evals()?
.into_iter()
.into();
processed_params = Some(GraphModules::forward(
@@ -638,8 +685,8 @@ impl GraphCircuit {
}
let input_data = match &data.input_data {
DataSource::File(input_data) => input_data,
DataSource::OnChain(_) => panic!(
WitnessSource::File(input_data) => input_data,
WitnessSource::OnChain(_) => panic!(
"Cannot use on-chain data source as input for on-chain test.
Will manually populate on-chain data from file source instead"
),
@@ -647,7 +694,7 @@ impl GraphCircuit {
// Get the flatten length of input_data
let length = input_data.iter().map(|x| x.len()).sum();
let scales = vec![self.settings.run_args.scale; length];
let datam: (Vec<Tensor<i128>>, OnChainSourceInner) =
let datam: (Vec<Tensor<Fp>>, OnChainSourceInner) =
OnChainSourceInner::test_from_file_data(
input_data,
scales,
@@ -658,7 +705,13 @@ impl GraphCircuit {
self.inputs = datam.0;
data.input_data = datam.1.into();
} else {
self.load_inputs(&data.clone().into()).await?;
self.inputs = self
.process_witness_source(
&data.input_data,
self.model.graph.input_shapes(),
self.model.graph.get_input_scales(),
)
.await?;
}
if matches!(
test_on_chain_data.data_sources.output,
@@ -670,13 +723,13 @@ impl GraphCircuit {
}
let output_data = match &data.output_data {
DataSource::File(output_data) => output_data,
DataSource::OnChain(_) => panic!(
WitnessSource::File(output_data) => output_data,
WitnessSource::OnChain(_) => panic!(
"Cannot use on-chain data source as output for on-chain test.
Will manually populate on-chain data from file source instead"
),
};
let datum: (Vec<Tensor<i128>>, OnChainSourceInner) =
let datum: (Vec<Tensor<Fp>>, OnChainSourceInner) =
OnChainSourceInner::test_from_file_data(
output_data,
self.model.graph.get_output_scales(),
@@ -687,7 +740,13 @@ impl GraphCircuit {
self.outputs = datum.0;
data.output_data = datum.1.into();
} else {
self.load_outputs(data).await?;
self.outputs = self
.process_witness_source(
&data.input_data,
self.model.graph.input_shapes(),
self.model.graph.get_output_scales(),
)
.await?;
}
// Save the updated GraphInput struct to the data_path
data.save(test_on_chain_data.data)?;
@@ -758,7 +817,7 @@ impl Circuit<Fp> for GraphCircuit {
let mut inputs = self
.inputs
.iter()
.map(|i| ValTensor::from(<Tensor<i128> as Into<Tensor<Value<Fp>>>>::into(i.clone())))
.map(|i| ValTensor::from(i.map(|x| Value::known(x))))
.collect::<Vec<ValTensor<Fp>>>();
let mut instance_offset = ModuleInstanceOffset::new();

View File

@@ -8,6 +8,7 @@ use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::RegionCtx;
use crate::circuit::Input;
use crate::circuit::Unknown;
use crate::fieldutils::felt_to_i128;
use crate::{
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
commands::RunArgs,
@@ -40,7 +41,7 @@ use tract_onnx::prelude::Framework;
#[derive(Clone, Debug)]
pub struct ForwardResult {
/// The outputs of the forward pass.
pub outputs: Vec<Tensor<i128>>,
pub outputs: Vec<Tensor<Fp>>,
/// The maximum value of any input to a lookup operation.
pub max_lookup_inputs: i128,
}
@@ -182,6 +183,14 @@ impl ParsedNodes {
.collect_vec()
}
/// Returns the fixed point scale of the computational graph's inputs
pub fn get_input_scales(&self) -> Vec<u32> {
let input_nodes = self.inputs.iter();
input_nodes
.flat_map(|o| self.nodes.get(o).unwrap().out_scales())
.collect_vec()
}
/// Returns the fixed point scale of the computational graph's outputs
pub fn get_output_scales(&self) -> Vec<u32> {
let output_nodes = self.outputs.iter();
@@ -280,8 +289,8 @@ impl Model {
/// * `reader` - A reader for an Onnx file.
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
/// * `run_args` - [RunArgs]
pub fn forward(&self, model_inputs: &[Tensor<i128>]) -> Result<ForwardResult, Box<dyn Error>> {
let mut results: BTreeMap<&usize, Tensor<i128>> = BTreeMap::new();
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
let mut results: BTreeMap<&usize, Tensor<Fp>> = BTreeMap::new();
let mut max_lookup_inputs = 0;
let mut input_idx = 0;
for (idx, n) in self.graph.nodes.iter() {
@@ -305,7 +314,7 @@ impl Model {
if !n.required_lookups().is_empty() {
let mut max = 0;
for i in &inputs {
max = max.max(i.iter().map(|x| x.abs()).max().unwrap());
max = max.max(i.iter().map(|x| felt_to_i128(*x).abs()).max().unwrap());
}
max_lookup_inputs = max_lookup_inputs.max(max);
}

View File

@@ -2,7 +2,6 @@ use crate::circuit::modules::elgamal::{ElGamalConfig, ElGamalGadget, ElGamalVari
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
use crate::circuit::modules::Module;
use crate::fieldutils::i128_to_felt;
use crate::tensor::{Tensor, ValTensor, ValType};
use halo2_proofs::circuit::{Layouter, Value};
use halo2_proofs::plonk::{ConstraintSystem, Error};
@@ -320,7 +319,6 @@ impl GraphModules {
instance_offset: &mut [usize],
) -> Result<(), Error> {
// reserve module 0 for ... modules
values.iter_mut().for_each(|x| {
// hash the input and replace the constrained cells in the input
let cloned_x = (*x).clone();
@@ -414,7 +412,7 @@ impl GraphModules {
/// Run forward pass
pub fn forward(
inputs: &[Tensor<i128>],
inputs: &[Tensor<Fp>],
element_visibility: Visibility,
) -> Result<ModuleForwardResult, Box<dyn std::error::Error>> {
let mut rng = &mut rand::thread_rng();
@@ -423,8 +421,7 @@ impl GraphModules {
if element_visibility.is_hashed() {
let field_elements = inputs.iter().fold(vec![], |mut acc, x| {
let field_elements = x.iter().map(|x| i128_to_felt::<Fp>(*x)).collect();
let res = ModulePoseidon::run(field_elements).unwrap()[0].clone();
let res = ModulePoseidon::run(x.to_vec()).unwrap()[0].clone();
acc.extend(res);
acc
});
@@ -435,9 +432,7 @@ impl GraphModules {
let variables = ElGamalVariables::gen_random(&mut rng);
let elgamal_outputs = inputs.iter().fold(vec![], |mut acc: Vec<Vec<Fp>>, x| {
let field_elements = x.iter().map(|x| i128_to_felt::<Fp>(*x)).collect();
let ciphers = ElGamalGadget::run((field_elements, variables.clone())).unwrap();
let ciphers = ElGamalGadget::run((x.to_vec(), variables.clone())).unwrap();
if acc.is_empty() {
ciphers
} else {

View File

@@ -35,9 +35,9 @@ use tract_onnx::tract_hir::{
/// * `dims` - the dimensionality of the resulting [Tensor].
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float(elem: &f32, shift: f32, scale: u32) -> Result<i128, TensorError> {
let mult = scale_to_multiplier(scale) as f32;
let max_value = ((i128::MAX as f32 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
pub fn quantize_float(elem: &f64, shift: f64, scale: u32) -> Result<i128, TensorError> {
let mult = scale_to_multiplier(scale);
let max_value = ((i128::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
if *elem > max_value {
return Err(TensorError::SigBitTruncationError);
@@ -774,14 +774,16 @@ pub fn tensor_to_valtensor<F: PrimeField + TensorType + PartialOrd>(
Visibility::Public => const_value
.map(|x| {
crate::tensor::ValType::Constant(crate::fieldutils::i128_to_felt::<F>(
quantize_float(&x, 0.0, scale).unwrap(),
quantize_float(&x.into(), 0.0, scale).unwrap(),
))
})
.into(),
Visibility::Private | Visibility::Hashed | Visibility::Encrypted => const_value
.map(|x| {
crate::tensor::ValType::Value(halo2_proofs::circuit::Value::known(
crate::fieldutils::i128_to_felt::<F>(quantize_float(&x, 0.0, scale).unwrap()),
crate::fieldutils::i128_to_felt::<F>(
quantize_float(&x.into(), 0.0, scale).unwrap(),
),
))
})
.into(),

View File

@@ -195,8 +195,39 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
}
/// Fetch the underlying [Tensor] of field elements.
pub fn get_felt_evals(&self) -> Result<Tensor<F>, Box<dyn Error>> {
let mut felt_evals: Vec<F> = vec![];
match self {
ValTensor::Value {
inner: v, dims: _, ..
} => {
// we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want)
let _ = v.map(|vaf| match vaf {
ValType::Value(v) => v.map(|f| {
felt_evals.push(f);
}),
ValType::AssignedValue(v) => v.map(|f| {
felt_evals.push(f.evaluate());
}),
ValType::PrevAssigned(v) => v.value_field().map(|f| {
felt_evals.push(f.evaluate());
}),
ValType::Constant(v) => {
felt_evals.push(v);
Value::unknown()
}
});
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
let res: Tensor<F> = felt_evals.into_iter().into();
Ok(res)
}
/// Calls `int_evals` on the inner tensor.
pub fn get_int_evals(&self) -> Result<Vec<i128>, Box<dyn Error>> {
pub fn get_int_evals(&self) -> Result<Tensor<i128>, Box<dyn Error>> {
// finally convert to vector of integers
let mut integer_evals: Vec<i128> = vec![];
match self {
@@ -222,7 +253,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
Ok(integer_evals)
Ok(integer_evals.into_iter().into())
}
/// Calls `get_slice` on the inner tensor.

View File

@@ -174,8 +174,9 @@ impl VarTensor {
offset: usize,
constant: F
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error>{
let (x, y) = self.cartesian_coord(offset);
match &self {
VarTensor::Advice { inner: advices, .. } => {
region.assign_advice_from_constant(|| "constant", advices[x], y, constant)
@@ -184,10 +185,10 @@ impl VarTensor {
region.assign_fixed(|| "constant", fixed[x], y, || Value::known(constant))
}
_ => panic!()
}}
/// Assigns [ValTensor] to the columns of the inner tensor.
pub fn assign<F: PrimeField + TensorType + PartialOrd>(
&self,

View File

@@ -209,7 +209,7 @@ pub fn prove_wasm(
.unwrap();
// prep public inputs
circuit.load_data(&data).unwrap();
circuit.load_graph_witness(&data).unwrap();
let public_inputs = circuit.prepare_public_inputs(&data).unwrap();
let strategy = KZGSingleStrategy::new(&params);

View File

@@ -3,8 +3,8 @@
mod native_tests {
use core::panic;
use ezkl_lib::graph::input::GraphInput;
use ezkl_lib::graph::DataSource;
use ezkl_lib::graph::GraphWitness;
use lazy_static::lazy_static;
use std::env::var;
use std::process::Command;
@@ -126,7 +126,7 @@ mod native_tests {
assert!(status.success());
let data = GraphWitness::from_path(format!("{}/{}/input.json", test_dir, test).into())
let data = GraphInput::from_path(format!("{}/{}/input.json", test_dir, test).into())
.expect("failed to load input data");
let input_data = match data.input_data {
@@ -134,25 +134,12 @@ mod native_tests {
DataSource::OnChain(_) => panic!("Only File data sources support batching"),
};
let output_data = match data.output_data {
DataSource::File(data) => data,
DataSource::OnChain(_) => panic!("Only File data sources support batching"),
};
let duplicated_input_data: Vec<Vec<f32>> = input_data
let duplicated_input_data: Vec<Vec<f64>> = input_data
.iter()
.map(|data| (0..num_batches).flat_map(|_| data.clone()).collect())
.collect();
let duplicated_output_data: Vec<Vec<f32>> = output_data
.iter()
.map(|data| (0..num_batches).flat_map(|_| data.clone()).collect())
.collect();
let duplicated_data = GraphWitness::new(
DataSource::File(duplicated_input_data),
DataSource::File(duplicated_output_data),
);
let duplicated_data = GraphInput::new(DataSource::File(duplicated_input_data));
let res =
duplicated_data.save(format!("{}/{}/input.json", test_dir, output_dir).into());

View File

@@ -153,10 +153,8 @@ def test_forward():
with open(output_path, "r") as f:
data = json.load(f)
assert data["input_data"] == res["input_data"] == [[0.05326242372393608, 0.07497056573629379, 0.05235547572374344, 0.028825461864471436, 0.05848702788352966,
0.008225822821259499, 0.07530029118061066, 0.0821458026766777, 0.06227986887097359, 0.024306034669280052, 0.05793173983693123, 0.040442030876874924]]
assert data["output_data"] == res["output_data"] == [[0.05322265625, 0.12841796875, 0.0751953125, 0.10546875, 0.20947265625, 0.10400390625, 0.05224609375, 0.0810546875, 0.02880859375, 0.05859375, 0.06689453125,
0.00830078125, 0.1337890625, 0.22412109375, 0.09033203125, 0.0751953125, 0.1572265625, 0.08203125, 0.0625, 0.0869140625, 0.0244140625, 0.12060546875, 0.185546875, 0.06494140625, 0.05810546875, 0.0986328125, 0.04052734375]]
assert data["input_data"] == res["input_data"]
assert data["output_data"] == res["output_data"]
assert data["processed_inputs"]["poseidon_hash"] == res["processed_inputs"]["poseidon_hash"] == [[
8270957937025516140, 11801026918842104328, 2203849898884507041, 140307258138425306]]
@@ -452,8 +450,16 @@ async def aggregate_and_verify_aggr():
proof_path = os.path.join(folder_path, '1l_relu.pf')
output_path = os.path.join(
folder_path,
'1l_relu_aggr_witness.json'
)
res = ezkl_lib.gen_witness(data_path, model_path,
output_path, settings_path=settings_path)
ezkl_lib.prove(
data_path,
output_path,
model_path,
pk_path,
proof_path,
@@ -540,8 +546,16 @@ async def evm_aggregate_and_verify_aggr():
proof_path = os.path.join(folder_path, '1l_relu.pf')
output_path = os.path.join(
folder_path,
'1l_relu_aggr_evm_witness.json'
)
res = ezkl_lib.gen_witness(data_path, model_path,
output_path, settings_path=settings_path)
ezkl_lib.prove(
data_path,
output_path,
model_path,
pk_path,
proof_path,

View File

@@ -1,407 +1,52 @@
{
"input_data": [
[
1.5417295,
0.5346153,
1.2172532
[
6425625360762666998,
7924344314350639699,
14762033076929465436,
2023505479389396574
],
[
12436184717236109307,
3962172157175319849,
7381016538464732718,
1011752739694698287
],
[
12436184717236109307,
3962172157175319849,
7381016538464732718,
1011752739694698287
]
]
],
"output_data": [
[
0.0,
0.0,
0.0,
0.0
[
0,
0,
0,
0
],
[
0,
0,
0,
0
],
[
0,
0,
0,
0
],
[
0,
0,
0,
0
]
]
],
"processed_inputs": {
"poseidon_hash": [
[
1205357537771999130,
5984390958877082714,
9820729522268186501,
2234099058051285368
]
],
"elgmal_results": {
"variables": {
"r": [
8241161763432706976,
666255099567061377,
14805312420564963680,
1204508807060214874
],
"pk": {
"x": [
15190346206350482033,
5345466492465898205,
15673625536425086517,
2946938976802915929
],
"y": [
3849689718459676759,
2554904641803052543,
17914557513841193544,
2568907273734445593
]
},
"sk": [
11176153975878706009,
4400691958827003321,
2763850695090718903,
2596642706525330605
],
"window_size": 1,
"aux_generator": {
"x": [
10023487266988000827,
5757424535551704030,
14065811819017675132,
708744013093495848
],
"y": [
3035630535707533851,
15084156752173132954,
4896068836772202981,
3481542306236198361
]
}
},
"ciphertexts": [
[],
[
[
17425460071899514968,
1167359589260142187,
6756444933589811955,
2891408151901966517
],
[
829458805722111377,
17250282503078918728,
7016627888695675713,
2925376423043353791
]
],
[
[
7811839396291636301,
16626502080672679222,
7596684602016322355,
1171033058875794198
],
[
13822398752765078610,
12664329923497359372,
215668063551589637,
159280319181095911
],
[
13822398752765078610,
12664329923497359372,
215668063551589637,
159280319181095911
]
],
[
[
9283435499189890162,
12255697202589512688,
8119465263255675837,
3239333981286649072
]
]
]
}
},
"processed_params": {
"poseidon_hash": [
[
2438338217603384681,
4147850557857422208,
14984698365426657752,
2960724354489209719
]
],
"elgamal": {
"variables": {
"r": [
1200181915139776053,
13180175622696807913,
14476695540645645352,
1345741327868623972
],
"pk": {
"x": [
9797896387342524947,
15035349603715646648,
13792633755538567363,
1133541516276092058
],
"y": [
9476212339347896300,
11232279511380632903,
13562657712233073770,
2203529712687503461
]
},
"sk": [
15504567943908784228,
16286748745616108518,
11666349446092175813,
739552556110375646
],
"window_size": 1,
"aux_generator": {
"x": [
3146347569958584183,
5994772996379248674,
12027192890968946425,
250337585815942863
],
"y": [
15759508736901521891,
13730739491295063583,
12178033622663772459,
328932231313987632
]
}
},
"ciphertexts": [
[],
[
[
3687758631139381896,
16990020959593879569,
11442997537587059238,
2601192457456256832
],
[
9698149576377935801,
11180115031102909471,
10273504557556343240,
1717383746578603920
]
],
[
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
3919368951425050505,
13231007347102958053,
15773012677856519663,
3274073910421788953
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
3919368951425050505,
13231007347102958053,
15773012677856519663,
3274073910421788953
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
3919368951425050505,
13231007347102958053,
15773012677856519663,
3274073910421788953
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
5453533626151118718,
18258437278146751399,
17253853803511290350,
1810581123008214862
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
],
[
11464092982624561027,
14296265120971431549,
9872837265046557632,
798828383313516575
]
],
[
[
11835554726384377649,
3476093169659441207,
10317446023932452752,
1341378657344074082
]
]
]
}
},
"processed_outputs": {
"poseidon_hash": [
[
13525436761904490906,
5108452989262774252,
17113314644501751557,
3448479653599714018
]
],
"elgmal_results": {
"variables": {
"r": [
4178821741174955177,
14671768628465678513,
4455016481844455948,
596401571849164806
],
"pk": {
"x": [
17993908268456464107,
12767282455815797736,
6841806501674431080,
147587434280887521
],
"y": [
16992425197354392191,
9664339530271872349,
1076902239844679755,
2685162918446917537
]
},
"sk": [
1307030728222113769,
11176383346436204449,
2128349950767135257,
1684071711526912194
],
"window_size": 1,
"aux_generator": {
"x": [
4657277071614079755,
1369367041476422289,
14912396191352368123,
3115177155409298233
],
"y": [
11259391534341870480,
12985819015883975733,
5695317503694964668,
1609774512612003893
]
}
},
"ciphertexts": [
[],
[
[
310776321632342963,
8782220497026581904,
967332862939820787,
2438830089048710677
],
[
17188119484612178357,
5607787842121139250,
12990365590313846177,
2940512950316690900
]
],
[
[
11468406689166593318,
9609755014076369146,
782845172795013879,
2433980276123347322
],
[
11468406689166593318,
9609755014076369146,
782845172795013879,
2433980276123347322
],
[
11468406689166593318,
9609755014076369146,
782845172795013879,
2433980276123347322
],
[
11468406689166593318,
9609755014076369146,
782845172795013879,
2433980276123347322
]
],
[
[
15812472742830711787,
11777278746141567566,
4910275528091806968,
1659402969803748057
]
]
]
}
}
]
}