Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-13 08:17:57 -05:00)
Compare commits: ac/make-da...ac/floor-r (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | f44ffdb1cf | |
| | c6e541de22 | |
| | 71d6678a0b | |
| | 145c7a19f4 | |
examples/onnx/integer_div/gen.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+from torch import nn
+import torch
+import json
+import numpy as np
+
+
+class MyModel(nn.Module):
+    def __init__(self):
+        super(MyModel, self).__init__()
+
+    def forward(self, x):
+        return x // 3
+
+
+circuit = MyModel()
+
+x = torch.randint(0, 10, (1, 2, 2, 8))
+
+out = circuit(x)
+
+print(x)
+print(out)
+print(x/3)
+
+torch.onnx.export(circuit, x, "network.onnx",
+                  export_params=True,  # store the trained parameter weights inside the model file
+                  opset_version=17,  # the ONNX version to export the model to
+                  do_constant_folding=True,  # whether to execute constant folding for optimization
+                  input_names=['input'],  # the model's input names
+                  output_names=['output'],  # the model's output names
+                  dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
+                                'output': {0: 'batch_size'}})
+
+
+d1 = ((x).detach().numpy()).reshape([-1]).tolist()
+
+data = dict(
+    input_data=[d1],
+)
+
+# Serialize data into file:
+json.dump(data, open("input.json", 'w'))
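For a quick sanity check of the exported graph, something like the following works (a minimal sketch, assuming onnxruntime is installed; it is not part of this commit):

```python
# Hypothetical check: run the exported network.onnx and compare its output
# against PyTorch's floor division on the same integer input.
import onnxruntime as ort
import torch

x = torch.randint(0, 10, (1, 2, 2, 8))
sess = ort.InferenceSession("network.onnx")
(onnx_out,) = sess.run(None, {"input": x.numpy()})
assert (torch.from_numpy(onnx_out) == x // 3).all()
```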
examples/onnx/integer_div/input.json (new file, 1 line)
@@ -0,0 +1 @@
+{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}
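(The 32 values correspond to the 1×2×2×8 input tensor from gen.py, flattened row-major.)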
examples/onnx/integer_div/network.onnx (new binary file)
Binary file not shown.
@@ -157,7 +157,9 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     // implicitly check if the prover provided output is within range
     let claimed_output = identity(config, region, &[claimed_output], true)?;
     // check if x is too large only if the decomp would support overflow in the previous op
-    if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
+    if F::from_u128(IntegerRep::MAX as u128)
+        < F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
+    {
         // here we decompose and extract the sign of the input
         let sign = sign(config, region, &[claimed_output.clone()])?;
 
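One plausible reading of this change (the commit message is not shown here): the old guard evaluated `(region.base() as i128).pow(region.legs() as u32)` in i128, which can itself overflow for large decompositions, whereas the same quantity computed in the field F cannot. A numeric sketch with hypothetical parameters:

```python
# Hypothetical decomposition parameters where base**legs - 1 exceeds i128::MAX,
# so the old i128 pow() would overflow before the guard could even compare.
I128_MAX = 2**127 - 1
base, legs = 2**14, 10
print(base**legs - 1 > I128_MAX)  # True (Python integers are unbounded, so this is safe to evaluate)
```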
@@ -254,7 +256,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     )?;
 
     // check if x is too large only if the decomp would support overflow in the previous op
-    if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
+    if F::from_u128(IntegerRep::MAX as u128)
+        < F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
+    {
         // here we decompose and extract the sign of the input
         let sign = sign(config, region, &[masked_output.clone()])?;
         let abs_value = pairwise(
@@ -2652,9 +2656,9 @@ pub fn mean_of_squares_axes<F: PrimeField + TensorType + PartialOrd + std::hash:
     let squared = pow(config, region, values, 2)?;
     let sum_squared = sum_axes(config, region, &[squared], axes)?;
 
-    let dividand: usize = values[0].len() / sum_squared.len();
+    let dividend: usize = values[0].len() / sum_squared.len();
 
-    let mean_squared = div(config, region, &[sum_squared], F::from(dividand as u64))?;
+    let mean_squared = div(config, region, &[sum_squared], F::from(dividend as u64))?;
     Ok(mean_squared)
 }
 
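The renamed variable is the element count each output averages over. For intuition, a minimal sketch of the computation in plain NumPy (not the circuit code):

```python
# Mean of squares over all axes of a flat example: sum(x_i**2) / n,
# where n = input length / output length (here the output is a single value).
import numpy as np

values = np.array([1, 2, 3, 4])
dividend = values.size // 1            # the renamed 'dividend' (was 'dividand')
mean_squared = (values ** 2).sum() / dividend
print(mean_squared)                    # 7.5 == (1 + 4 + 9 + 16) / 4
```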
@@ -274,11 +274,9 @@ pub fn new_op_from_onnx(
     symbol_values: &SymbolValues,
     run_args: &crate::RunArgs,
 ) -> Result<(SupportedOp, Vec<usize>), GraphError> {
-    use std::f64::consts::E;
-
-    use tract_onnx::tract_core::ops::array::Trilu;
-
     use crate::circuit::InputType;
+    use std::f64::consts::E;
+    use tract_onnx::tract_core::ops::array::Trilu;
 
     let input_scales = inputs
         .iter()
@@ -1274,9 +1272,19 @@ pub fn new_op_from_onnx(
             // get the non constant index
             let denom = c.raw_values[0];
 
-            SupportedOp::Hybrid(HybridOp::Div {
+            let op = SupportedOp::Hybrid(HybridOp::Div {
                 denom: denom.into(),
-            })
+            });
+
+            // if the input is scale 0 we re-up to the max scale
+            if input_scales[0] == 0 {
+                SupportedOp::Rescaled(Rescaled {
+                    inner: Box::new(op),
+                    scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
+                })
+            } else {
+                op
+            }
         } else {
             return Err(GraphError::MisformedParams(
                 "only support non zero divisors of size 1".to_string(),
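A sketch of the arithmetic this wrapper performs, assuming ezkl's usual base-2 fixed-point convention (scale_to_multiplier(s) == 2**s) and a hypothetical value for scales.get_max():

```python
# Sketch of the scale-0 re-up applied before the Div op.
def scale_to_multiplier(scale: int) -> float:
    return 2.0 ** scale          # assumed base-2 convention

max_scale = 7                    # hypothetical scales.get_max()
multiplier = int(scale_to_multiplier(max_scale))   # 128
x_scale0 = 5                     # a raw integer (scale-0) input
x_rescaled = x_scale0 * multiplier                 # brought up to max scale before dividing
print(x_rescaled)                # 640
```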
@@ -206,7 +206,7 @@ mod native_tests {
         "1l_tiny_div",
     ];
 
-    const TESTS: [&str; 98] = [
+    const TESTS: [&str; 99] = [
         "1l_mlp", //0
         "1l_slice", //1
         "1l_concat", //2
@@ -309,6 +309,7 @@ mod native_tests {
         "log", // 95
         "exp", // 96
         "general_exp", // 97
+        "integer_div", // 98
     ];
 
     const WASM_TESTS: [&str; 46] = [
@@ -547,7 +548,7 @@ mod native_tests {
         }
     });
 
-    seq!(N in 0..=97 {
+    seq!(N in 0..=98 {
 
         #(#[test_case(TESTS[N])])*
         #[ignore]