Compare commits

...

4 Commits

Author SHA1 Message Date
Alexander Camuto
f44ffdb1cf Update utilities.rs 2025-02-04 14:31:05 +00:00
Alexander Camuto
c6e541de22 Update ops.rs 2025-02-04 14:05:55 +00:00
Alexander Camuto
71d6678a0b files 2025-02-04 14:00:25 +00:00
Alexander Camuto
145c7a19f4 fix: use onnx convention when integer dividing 2025-02-04 13:57:35 +00:00
6 changed files with 68 additions and 12 deletions

View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
return x // 3
circuit = MyModel()
x = torch.randint(0, 10, (1, 2, 2, 8))
out = circuit(x)
print(x)
print(out)
print(x/3)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}

Binary file not shown.

View File

@@ -157,7 +157,9 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// implicitly check if the prover provided output is within range
let claimed_output = identity(config, region, &[claimed_output], true)?;
// check if x is too large only if the decomp would support overflow in the previous op
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[claimed_output.clone()])?;
@@ -254,7 +256,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
)?;
// check if x is too large only if the decomp would support overflow in the previous op
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[masked_output.clone()])?;
let abs_value = pairwise(
@@ -2652,9 +2656,9 @@ pub fn mean_of_squares_axes<F: PrimeField + TensorType + PartialOrd + std::hash:
let squared = pow(config, region, values, 2)?;
let sum_squared = sum_axes(config, region, &[squared], axes)?;
let dividand: usize = values[0].len() / sum_squared.len();
let dividend: usize = values[0].len() / sum_squared.len();
let mean_squared = div(config, region, &[sum_squared], F::from(dividand as u64))?;
let mean_squared = div(config, region, &[sum_squared], F::from(dividend as u64))?;
Ok(mean_squared)
}

View File

@@ -274,11 +274,9 @@ pub fn new_op_from_onnx(
symbol_values: &SymbolValues,
run_args: &crate::RunArgs,
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use std::f64::consts::E;
use tract_onnx::tract_core::ops::array::Trilu;
use crate::circuit::InputType;
use std::f64::consts::E;
use tract_onnx::tract_core::ops::array::Trilu;
let input_scales = inputs
.iter()
@@ -1274,9 +1272,19 @@ pub fn new_op_from_onnx(
// get the non constant index
let denom = c.raw_values[0];
SupportedOp::Hybrid(HybridOp::Div {
let op = SupportedOp::Hybrid(HybridOp::Div {
denom: denom.into(),
})
});
// if the input is scale 0 we re up to the max scale
if input_scales[0] == 0 {
SupportedOp::Rescaled(Rescaled {
inner: Box::new(op),
scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
})
} else {
op
}
} else {
return Err(GraphError::MisformedParams(
"only support non zero divisors of size 1".to_string(),

View File

@@ -206,7 +206,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 98] = [
const TESTS: [&str; 99] = [
"1l_mlp", //0
"1l_slice", //1
"1l_concat", //2
@@ -309,6 +309,7 @@ mod native_tests {
"log", // 95
"exp", // 96
"general_exp", // 97
"integer_div", // 98
];
const WASM_TESTS: [&str; 46] = [
@@ -547,7 +548,7 @@ mod native_tests {
}
});
seq!(N in 0..=97 {
seq!(N in 0..=98 {
#(#[test_case(TESTS[N])])*
#[ignore]