Compare commits

...

4 Commits

Author SHA1 Message Date
dante
95d4fd4a70 feat: power of 2 div using type system (#702) 2024-02-04 02:43:38 +00:00
dante
e0d3f4f145 fix: uncomparable values in acc table (#701) 2024-02-02 15:13:29 +00:00
dante
bceac2fab5 ci: make gpu tests single threaded (#700) 2024-01-31 18:19:29 +00:00
dante
04d7b5feaa chore: fold div_rebasing parameter into calibration (#699) 2024-01-31 10:03:12 +00:00
16 changed files with 266 additions and 123 deletions

View File

@@ -427,21 +427,21 @@ jobs:
crate: cargo-nextest
locked: true
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public inputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
- name: KZG prove and verify tests (fixed params)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
fuzz-tests:
runs-on: ubuntu-latest-32-cores

View File

@@ -0,0 +1,39 @@
from torch import nn
import torch
import json
class Circuit(nn.Module):
def __init__(self, inplace=False):
super(Circuit, self).__init__()
def forward(self, x):
return x/ 10000
circuit = Circuit()
x = torch.empty(1, 8).random_(0, 2)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]]}

Binary file not shown.

View File

@@ -13,6 +13,10 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
Div {
denom: utils::F32,
use_range_check_for_int: bool,
},
ReduceMax {
axes: Vec<usize>,
},
@@ -113,6 +117,21 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(res.clone(), vec![inter_1, inter_2])
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
// if denom is a round number and use_range_check_for_int is true, use range check check
if denom.0.fract() == 0.0 && *use_range_check_for_int {
let divisor = Tensor::from(vec![denom.0 as i128].into_iter());
let res = crate::tensor::ops::div(&[x, divisor.clone()])?;
(res, vec![-divisor.clone(), divisor])
} else {
let res = crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64);
(res, vec![x])
}
}
HybridOp::ReduceArgMax { dim } => {
let res = tensor::ops::argmax_axes(&x, *dim)?;
let indices = Tensor::from(0..x.dims()[*dim] as i128);
@@ -272,6 +291,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn as_string(&self) -> String {
match self {
HybridOp::Div {
denom,
use_range_check_for_int,
} => format!(
"DIV (denom={}, use_range_check_for_int={})",
denom, use_range_check_for_int
),
HybridOp::SumPool {
padding,
stride,
@@ -335,6 +361,29 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
*kernel_shape,
*normalized,
)?,
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::div(
config,
region,
values[..].try_into()?,
i128_to_felt(denom.0 as i128),
)?
} else {
layouts::nonlinearity(
config,
region,
values.try_into()?,
&LookupOp::Div {
denom: denom.clone(),
},
)?
}
}
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -427,11 +476,41 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
Ok(scale)
}
fn required_range_checks(&self) -> Vec<Range> {
match self {
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
vec![(-denom.0 as i128 + 1, denom.0 as i128 - 1)]
} else {
vec![]
}
}
_ => vec![],
}
}
fn required_lookups(&self) -> Vec<LookupOp> {
match self {
HybridOp::ReduceMax { .. }
| HybridOp::ReduceMin { .. }
| HybridOp::MaxPool2d { .. } => Op::<F>::required_lookups(&LookupOp::ReLU),
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
vec![]
} else {
vec![LookupOp::Div {
denom: denom.clone(),
}]
}
}
HybridOp::Softmax { scale, .. } => {
vec![
LookupOp::Exp { scale: *scale },

View File

@@ -33,7 +33,9 @@ pub enum PolyOp {
Sub,
Neg,
Mult,
Identity,
Identity {
out_scale: Option<crate::Scale>,
},
Reshape(Vec<usize>),
MoveAxis {
source: usize,
@@ -85,7 +87,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Resize { .. } => "RESIZE".into(),
PolyOp::Iff => "IFF".into(),
PolyOp::Einsum { equation, .. } => format!("EINSUM {}", equation),
PolyOp::Identity => "IDENTITY".into(),
PolyOp::Identity { out_scale } => {
format!("IDENTITY (out_scale={:?})", out_scale)
}
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
PolyOp::Flatten(_) => "FLATTEN".into(),
PolyOp::Pad(_) => "PAD".into(),
@@ -135,7 +139,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
PolyOp::Identity => Ok(inputs[0].clone()),
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
PolyOp::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
@@ -264,7 +268,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
@@ -322,9 +326,8 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
output_scale
}
PolyOp::Add => {
let mut scale_a = 0;
let scale_b = in_scales[0];
scale_a += in_scales[1];
let scale_a = in_scales[0];
let scale_b = in_scales[1];
assert_eq!(scale_a, scale_b);
scale_a
}
@@ -336,19 +339,19 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
_ => in_scales[0],
};
Ok(scale)
}
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
if matches!(
self,
PolyOp::Add { .. } | PolyOp::Sub | PolyOp::Concat { .. }
) {
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
vec![0, 1]
} else if matches!(self, PolyOp::Iff) {
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else {
vec![]
}

View File

@@ -336,6 +336,9 @@ pub enum Commands {
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
max_logrows: Option<u32>,
// whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase
#[arg(long)]
div_rebasing: Option<bool>,
},
/// Generates a dummy SRS

View File

@@ -178,6 +178,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
scales,
scale_rebase_multiplier,
max_logrows,
div_rebasing,
} => calibrate(
model,
data,
@@ -186,6 +187,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
max_logrows,
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -628,6 +630,10 @@ pub(crate) async fn gen_witness(
if let Some(output_path) = output {
serde_json::to_writer(&File::create(output_path)?, &witness)?;
}
// print the witness in debug
debug!("witness: \n {}", witness.as_json()?.to_colored_json_auto()?);
Ok(witness)
}
@@ -735,22 +741,22 @@ impl AccuracyResults {
let median_error = errors[errors.len() / 2];
let max_error = *errors
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let min_error = *errors
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let mean_abs_error = abs_errors.iter().sum::<f32>() / abs_errors.len() as f32;
let median_abs_error = abs_errors[abs_errors.len() / 2];
let max_abs_error = *abs_errors
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let min_abs_error = *abs_errors
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let mean_squared_error = squared_errors.iter().sum::<f32>() / squared_errors.len() as f32;
@@ -782,6 +788,7 @@ pub(crate) fn calibrate(
lookup_safety_margin: i128,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
div_rebasing: Option<bool>,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use std::collections::HashMap;
@@ -825,6 +832,12 @@ pub(crate) fn calibrate(
}
};
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
vec![div_rebasing]
} else {
vec![true, false]
};
let mut found_params: Vec<GraphSettings> = vec![];
// 2 x 2 grid
@@ -862,15 +875,21 @@ pub(crate) fn calibrate(
.map(|(a, b)| (*a, *b))
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
let range_grid = range_grid
.iter()
.cartesian_product(div_rebasing.iter())
.map(|(a, b)| (*a, *b))
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
let mut forward_pass_res = HashMap::new();
let pb = init_bar(range_grid.len() as u64);
pb.set_message("calibrating...");
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
pb.set_message(format!(
"input scale: {}, param scale: {}, scale rebase multiplier: {}",
input_scale, param_scale, scale_rebase_multiplier
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
));
#[cfg(unix)]
@@ -890,6 +909,7 @@ pub(crate) fn calibrate(
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
..settings.run_args.clone()
};
@@ -964,6 +984,7 @@ pub(crate) fn calibrate(
let found_run_args = RunArgs {
input_scale: new_settings.run_args.input_scale,
param_scale: new_settings.run_args.param_scale,
div_rebasing: new_settings.run_args.div_rebasing,
lookup_range: new_settings.run_args.lookup_range,
logrows: new_settings.run_args.logrows,
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,

View File

@@ -968,8 +968,8 @@ impl GraphCircuit {
lookup_safety_margin * max_lookup_inputs,
);
if lookup_safety_margin == 1 {
margin.0 += 1;
margin.1 += 1;
margin.0 += 4;
margin.1 += 4;
}
margin

View File

@@ -591,6 +591,8 @@ impl Model {
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
);
debug!("input nodes: {:?}", n.inputs());
if n.is_lookup() {
let (mut min, mut max) = (0, 0);
for i in &inputs {

View File

@@ -12,8 +12,6 @@ use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
use crate::fieldutils::felt_to_i128;
use crate::fieldutils::i128_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::new_op_from_onnx;
use crate::tensor::Tensor;
@@ -126,14 +124,14 @@ impl Op<Fp> for Rescaled {
pub struct RebaseScale {
/// The operation that has to be rescaled.
pub inner: Box<SupportedOp>,
/// the multiplier applied to the node output
pub multiplier: f64,
/// rebase op
pub rebase_op: HybridOp,
/// scale being rebased to
pub target_scale: i32,
/// The original scale of the operation's inputs.
pub original_scale: i32,
/// if true then the operation is a multiplicative division
pub div_rebasing: bool,
/// multiplier
pub multiplier: f64,
}
impl RebaseScale {
@@ -152,20 +150,27 @@ impl RebaseScale {
let multiplier =
scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
if let Some(op) = inner.get_rebased() {
let multiplier = op.multiplier * multiplier;
SupportedOp::RebaseScale(RebaseScale {
inner: op.inner.clone(),
target_scale: op.target_scale,
multiplier: op.multiplier * multiplier,
multiplier: multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op.original_scale,
div_rebasing,
})
} else {
SupportedOp::RebaseScale(RebaseScale {
inner: Box::new(inner),
target_scale: global_scale * scale_rebase_multiplier as i32,
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op_out_scale,
div_rebasing,
})
}
} else {
@@ -183,12 +188,16 @@ impl RebaseScale {
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
if let Some(op) = inner.get_rebased() {
let multiplier = op.multiplier * multiplier;
SupportedOp::RebaseScale(RebaseScale {
inner: op.inner.clone(),
target_scale: op.target_scale,
multiplier: op.multiplier * multiplier,
multiplier,
original_scale: op.original_scale,
div_rebasing,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
})
} else {
SupportedOp::RebaseScale(RebaseScale {
@@ -196,22 +205,16 @@ impl RebaseScale {
target_scale,
multiplier,
original_scale: op_out_scale,
div_rebasing,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
})
}
} else {
inner
}
}
/// Calculate the require range bracket for the operation
fn range_bracket(&self) -> i128 {
if self.div_rebasing {
0
} else {
self.multiplier as i128 - 1
}
}
}
impl Op<Fp> for RebaseScale {
@@ -220,28 +223,19 @@ impl Op<Fp> for RebaseScale {
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
if self.div_rebasing {
let ri = res.output.map(felt_to_i128);
let rescaled = crate::tensor::ops::nonlinearities::const_div(&ri, self.multiplier);
res.output = rescaled.map(i128_to_felt);
res.intermediate_lookups.push(ri);
} else {
let ri = res.output.map(felt_to_i128);
let divisor = Tensor::from(vec![self.multiplier as i128].into_iter());
let rescaled = crate::tensor::ops::div(&[ri, divisor.clone()])?;
res.output = rescaled.map(i128_to_felt);
res.intermediate_lookups.extend([-divisor.clone(), divisor]);
}
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;
res.intermediate_lookups
.extend(rebase_res.intermediate_lookups);
Ok(res)
}
fn as_string(&self) -> String {
format!(
"REBASED (div={:?}, div_r={}) ({})",
"REBASED (div={:?}, rebasing_op={}) ({})",
self.multiplier,
self.div_rebasing,
<HybridOp as Op<Fp>>::as_string(&self.rebase_op),
self.inner.as_string()
)
}
@@ -252,20 +246,13 @@ impl Op<Fp> for RebaseScale {
fn required_lookups(&self) -> Vec<LookupOp> {
let mut lookups: Vec<LookupOp> = self.inner.required_lookups();
if self.div_rebasing {
lookups.push(LookupOp::Div {
denom: crate::circuit::utils::F32(self.multiplier as f32),
});
}
lookups.extend(Op::<Fp>::required_lookups(&self.rebase_op));
lookups
}
fn required_range_checks(&self) -> Vec<crate::circuit::table::Range> {
let mut range_checks = self.inner.required_range_checks();
if !self.div_rebasing {
let bracket = self.range_bracket();
range_checks.push((-bracket, bracket));
}
range_checks.extend(Op::<Fp>::required_range_checks(&self.rebase_op));
range_checks
}
@@ -278,25 +265,8 @@ impl Op<Fp> for RebaseScale {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or("no layout")?;
if !self.div_rebasing {
Ok(Some(crate::circuit::layouts::div(
config,
region,
&[original_res],
Fp::from(self.multiplier as u64),
)?))
} else {
Ok(Some(crate::circuit::layouts::nonlinearity(
config,
region,
&[original_res],
&LookupOp::Div {
denom: crate::circuit::utils::F32(self.multiplier as f32),
},
)?))
}
.ok_or("no inner layout")?;
self.rebase_op.layout(config, region, &[original_res])
}
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
@@ -584,8 +554,6 @@ impl Node {
symbol_values: &SymbolValues,
div_rebasing: bool,
) -> Result<Self, Box<dyn Error>> {
use log::warn;
trace!("Create {:?}", node);
trace!("Create op {:?}", node.op);
@@ -678,8 +646,6 @@ impl Node {
input_node.bump_scale(out_scale);
in_scales[input] = out_scale;
}
} else {
warn!("input {} not found for rescaling, skipping ...", input);
}
}

View File

@@ -261,7 +261,9 @@ pub fn new_op_from_onnx(
inputs[index].bump_scale(scale);
c.rebase_scale(scale)?;
inputs[index].replace_opkind(SupportedOp::Constant(c.clone()));
Ok(SupportedOp::Linear(PolyOp::Identity))
Ok(SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(scale),
}))
} else {
Ok(default_op)
}
@@ -282,8 +284,8 @@ pub fn new_op_from_onnx(
"shift left".to_string(),
)));
}
SupportedOp::Nonlinear(LookupOp::Div {
denom: crate::circuit::utils::F32(1.0 / 2.0f32.powf(raw_values[0])),
SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] - raw_values[0] as i32),
})
} else {
return Err(Box::new(GraphError::OpMismatch(
@@ -304,8 +306,8 @@ pub fn new_op_from_onnx(
"shift right".to_string(),
)));
}
SupportedOp::Nonlinear(LookupOp::Div {
denom: crate::circuit::utils::F32(2.0f32.powf(raw_values[0])),
SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] + raw_values[0] as i32),
})
} else {
return Err(Box::new(GraphError::OpMismatch(
@@ -559,6 +561,8 @@ pub fn new_op_from_onnx(
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
};
// if all raw_values are round then set scale to 0
// Quantize the raw value
let quantized_value =
quantize_tensor(raw_value.clone(), constant_scale, param_visibility)?;
@@ -665,8 +669,10 @@ pub fn new_op_from_onnx(
if unit == 0. {
SupportedOp::Nonlinear(LookupOp::ReLU)
} else {
// get the non-constant index
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
SupportedOp::Nonlinear(LookupOp::Max {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
a: crate::circuit::utils::F32(unit),
})
}
@@ -707,8 +713,11 @@ pub fn new_op_from_onnx(
deleted_indices.push(const_idx);
}
// get the non-constant index
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
SupportedOp::Nonlinear(LookupOp::Min {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
a: crate::circuit::utils::F32(unit),
})
} else {
@@ -718,7 +727,7 @@ pub fn new_op_from_onnx(
"Recip" => {
let in_scale = inputs[0].out_scales()[0];
// If the input scale is larger than the params scale
let scale_diff = std::cmp::max(scales.input, scales.params) - inputs[0].out_scales()[0];
let scale_diff = scales.get_max() - inputs[0].out_scales()[0];
let additional_scale = if scale_diff > 0 {
scale_to_multiplier(scale_diff)
} else {
@@ -751,7 +760,9 @@ pub fn new_op_from_onnx(
"Scan" => {
return Err("scan should never be analyzed explicitly".into());
}
"QuantizeLinearU8" | "DequantizeLinearF32" => SupportedOp::Linear(PolyOp::Identity),
"QuantizeLinearU8" | "DequantizeLinearF32" => {
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
}
"Abs" => SupportedOp::Nonlinear(LookupOp::Abs),
"Neg" => SupportedOp::Linear(PolyOp::Neg),
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
@@ -856,11 +867,11 @@ pub fn new_op_from_onnx(
}),
)?
} else {
SupportedOp::Linear(PolyOp::Identity)
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
}
}
DatumType::F16 | DatumType::F32 | DatumType::F64 => {
SupportedOp::Linear(PolyOp::Identity)
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
}
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
}
@@ -885,12 +896,15 @@ pub fn new_op_from_onnx(
let const_idx = const_idx[0];
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
inputs[const_idx].decrement_use();
deleted_indices.push(const_idx);
op = SupportedOp::Nonlinear(LookupOp::Div {
// we invert the constant for division
denom: crate::circuit::utils::F32(1. / c.raw_values[0]),
})
// if not divisible by 2 then we need to add a range check
let raw_values = 1.0 / c.raw_values[0];
if raw_values.log2().fract() == 0.0 {
inputs[const_idx].decrement_use();
deleted_indices.push(const_idx);
op = SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] + raw_values.log2() as i32),
});
}
}
}
}

View File

@@ -237,6 +237,11 @@ impl VarScales {
std::cmp::max(self.input, self.params)
}
///
pub fn get_min(&self) -> crate::Scale {
std::cmp::min(self.input, self.params)
}
/// Place in [VarScales] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
Ok(Self {

View File

@@ -521,6 +521,7 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
div_rebasing = None,
))]
fn calibrate_settings(
data: PathBuf,
@@ -531,6 +532,7 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
div_rebasing: Option<bool>,
) -> Result<bool, PyErr> {
crate::execute::calibrate(
model,
@@ -540,6 +542,7 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
max_logrows,
)
.map_err(|e| {

View File

@@ -182,12 +182,13 @@ mod native_tests {
"mnist_gan",
];
const ACCURACY_CAL_TESTS: [&str; 5] = [
const ACCURACY_CAL_TESTS: [&str; 6] = [
"accuracy",
"1l_mlp",
"4l_relu_conv_fc",
"1l_elu",
"1l_prelu",
"1l_tiny_div",
];
const TESTS: [&str; 77] = [
@@ -489,7 +490,7 @@ mod native_tests {
test_dir.close().unwrap();
}
seq!(N in 0..=4 {
seq!(N in 0..=5 {
#(#[test_case(ACCURACY_CAL_TESTS[N])])*
fn mock_accuracy_cal_tests(test: &str) {
crate::native_tests::init_binary();
@@ -836,10 +837,10 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
env_logger::init();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, Some(vec![0,1]), true, "single");
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm");
test_dir.close().unwrap();
// test_dir.close().unwrap();
}
#(#[test_case(WASM_TESTS[N])])*
@@ -849,7 +850,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
env_logger::init();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, Some(vec![0,1]), true, "single");
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm");
test_dir.close().unwrap();
@@ -865,7 +866,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, Some(vec![0,6]), false, "single");
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, None, false, "single");
test_dir.close().unwrap();
}
@@ -875,7 +876,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", Some(vec![0,6]));
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None);
test_dir.close().unwrap();
}
});
@@ -2032,7 +2033,7 @@ mod native_tests {
1,
"resources",
// we need the accuracy
Some(vec![7, 8]),
Some(vec![4]),
1,
false,
);

View File

@@ -78,14 +78,20 @@ def compare_outputs(zk_output, onnx_output):
zip_object = zip(np.array(zk_output).flatten(),
np.array(onnx_output).flatten())
for list1_i, list2_i in zip_object:
for (i, (list1_i, list2_i)) in enumerate(zip_object):
if list1_i == 0.0 and list2_i == 0.0:
res.append(0)
else:
diff = list1_i - list2_i
res.append(100 * (diff) / (list2_i))
# iterate and print the diffs if they are greater than 0.0
if abs(diff) > 0.0:
print("------- index: ", i)
print("------- diff: ", diff)
print("------- zk_output: ", list1_i)
print("------- onnx_output: ", list2_i)
print("res: ", res)
return np.mean(np.abs(res))