Compare commits

...

4 Commits

Author SHA1 Message Date
dante
5e169bdd17 chore: update tract to 0.21.8-pre (#878) 2024-12-03 16:52:03 -05:00
dante
64cbcb3f7e chore: explicitly compile div op (#876) 2024-11-28 17:14:53 +09:00
dante
ee17f0ff9a chore: generalize the exp to other bases (#875) 2024-11-26 09:31:12 +09:00
Jseam
ee55e7dc19 fix: upgrade run-on-arch (#874) 2024-11-24 14:30:42 +09:00
20 changed files with 308 additions and 147 deletions

View File

@@ -168,7 +168,7 @@ jobs:
name: wheels
path: dist
# TODO: There's a problem with the maturin-action toolchain for arm arch leading to failed builds
# There's a problem with the maturin-action toolchain for arm arch leading to failed builds
# linux-cross:
# runs-on: ubuntu-latest
# strategy:
@@ -306,7 +306,7 @@ jobs:
manylinux: musllinux_1_2
args: --release --out dist --features python-bindings
- uses: uraimo/run-on-arch-action@v2.5.0
- uses: uraimo/run-on-arch-action@v2.8.1
name: Install built wheel
with:
arch: ${{ matrix.platform.arch }}

View File

@@ -207,23 +207,6 @@ jobs:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Circuit Render
run: cargo nextest run --release --verbose tests::tutorial_
mock-proving-tests:
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
@@ -494,23 +477,23 @@ jobs:
- name: Mock aggr tests (KZG)
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
prove-and-verify-aggr-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: KZG )tests
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
# prove-and-verify-aggr-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG tests
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
@@ -614,8 +597,6 @@ jobs:
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
- name: fixed params

51
Cargo.lock generated
View File

@@ -598,6 +598,12 @@ version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
[[package]]
name = "anymap"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33954243bd79057c2de7338850b85983a44588021f8a5fee574a8888c6de4344"
[[package]]
name = "anymap2"
version = "0.13.0"
@@ -3322,14 +3328,16 @@ dependencies = [
[[package]]
name = "ndarray"
version = "0.15.6"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
@@ -3857,6 +3865,15 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
[[package]]
name = "portable-atomic-util"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507"
dependencies = [
"portable-atomic",
]
[[package]]
name = "poseidon"
version = "0.2.0"
@@ -5672,10 +5689,11 @@ dependencies = [
[[package]]
name = "tract-core"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
version = "0.21.8-pre"
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
dependencies = [
"anyhow",
"anymap",
"bit-set",
"derive-new",
"downcast-rs",
@@ -5696,8 +5714,8 @@ dependencies = [
[[package]]
name = "tract-data"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
version = "0.21.8-pre"
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
dependencies = [
"anyhow",
"downcast-rs",
@@ -5711,6 +5729,7 @@ dependencies = [
"nom",
"num-integer",
"num-traits",
"parking_lot",
"scan_fmt",
"smallvec",
"string-interner",
@@ -5718,8 +5737,8 @@ dependencies = [
[[package]]
name = "tract-hir"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
version = "0.21.8-pre"
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
dependencies = [
"derive-new",
"log",
@@ -5728,8 +5747,8 @@ dependencies = [
[[package]]
name = "tract-linalg"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
version = "0.21.8-pre"
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
dependencies = [
"byteorder",
"cc",
@@ -5755,8 +5774,8 @@ dependencies = [
[[package]]
name = "tract-nnef"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
version = "0.21.8-pre"
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
dependencies = [
"byteorder",
"flate2",
@@ -5769,8 +5788,8 @@ dependencies = [
[[package]]
name = "tract-onnx"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
version = "0.21.8-pre"
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
dependencies = [
"bytes",
"derive-new",
@@ -5786,8 +5805,8 @@ dependencies = [
[[package]]
name = "tract-onnx-opl"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
version = "0.21.8-pre"
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
dependencies = [
"getrandom",
"log",

View File

@@ -89,7 +89,7 @@ pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch =
"tokio-runtime",
], default-features = false, optional = true }
pyo3-log = { version = "0.10.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
objc = { version = "0.2.4", optional = true }

42
examples/onnx/exp/gen.py Normal file
View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.exp(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.5801457762718201, 0.6019012331962585, 0.8695418238639832, 0.17170941829681396, 0.500616729259491, 0.353726327419281, 0.6726185083389282, 0.5936906337738037]]}

View File

@@ -0,0 +1,14 @@
pytorch2.2.2:o

inputoutput/Exp"Exp
main_graphZ!
input


batch_size
b"
output


batch_size
B

View File

@@ -0,0 +1,41 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = 10**x
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.9837989807128906, 0.026381194591522217, 0.3403851389884949, 0.14531707763671875, 0.24652725458145142, 0.7945117354393005, 0.4076554775238037, 0.23064672946929932]]}

Binary file not shown.

View File

@@ -1,75 +1,52 @@
import torch
import torch.nn as nn
import sys
from torch import nn
import json
sys.path.append("..")
class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)
def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]
class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3
configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)
x = torch.randn(1, seq_len, pred_len)
import numpy as np
import tf2onnx
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
# gather_nd in tf then export to onnx
x = in1 = Input((4, 1), dtype=tf.int32)
w = in2 = Input((4, ), dtype=tf.int32)
class MyLayer(Layer):
def call(self, x, w):
shape = tf.constant([8])
return tf.scatter_nd(x, w, shape)
x = MyLayer()(x, w)
tm = Model((in1, in2), x)
tm.summary()
tm.compile(optimizer='adam', loss='mse')
shape = [1, 4, 1]
index_shape = [1, 4]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = np.random.randint(0, 4, shape)
# w = random int tensor
w = np.random.randint(0, 4, index_shape)
spec = tf.TensorSpec(shape, tf.int32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
data = dict(
input_data=[d1],
input_data=[d, d1],
)
# Serialize data into file:

View File

@@ -1 +1,16 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
{
"input_data": [
[
0,
1,
2,
3
],
[
1,
0,
2,
1
]
]
}

View File

@@ -1,5 +1,6 @@
use std::{
collections::{HashMap, HashSet},
f64::consts::E,
ops::Range,
};
@@ -1686,6 +1687,7 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::
Ok(output.into())
}
// assumes unique values in fullset
pub(crate) fn get_missing_set_elements<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
@@ -1698,6 +1700,8 @@ pub(crate) fn get_missing_set_elements<
let set_len = fullset.len();
input.flatten();
// while fullset is less than len of input concat
let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
@@ -5624,7 +5628,10 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config,
region,
&[sub],
&LookupOp::Exp { scale: input_scale },
&LookupOp::Exp {
scale: input_scale,
base: E.into(),
},
)?;
percent(config, region, &[ex.clone()], input_scale, output_scale)

View File

@@ -19,7 +19,7 @@ pub enum LookupOp {
PowersOfTwo { scale: utils::F32 },
Ln { scale: utils::F32 },
Sigmoid { scale: utils::F32 },
Exp { scale: utils::F32 },
Exp { scale: utils::F32, base: utils::F32 },
Cos { scale: utils::F32 },
ACos { scale: utils::F32 },
Cosh { scale: utils::F32 },
@@ -55,7 +55,7 @@ impl LookupOp {
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Erf { scale } => format!("erf_{}", scale),
LookupOp::Exp { scale } => format!("exp_{}", scale),
LookupOp::Exp { scale, base } => format!("exp_{}_{}", scale, base),
LookupOp::Cos { scale } => format!("cos_{}", scale),
LookupOp::ACos { scale } => format!("acos_{}", scale),
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
@@ -99,9 +99,9 @@ impl LookupOp {
LookupOp::Erf { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
}
LookupOp::Exp { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
}
LookupOp::Exp { scale, base } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::exp(&x, scale.into(), base.into()),
),
LookupOp::Cos { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
}
@@ -165,7 +165,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
LookupOp::Exp { scale } => format!("EXP(scale={})", scale),
LookupOp::Exp { scale, base } => format!("EXP(scale={}, base={})", scale, base),
LookupOp::Tan { scale } => format!("TAN(scale={})", scale),
LookupOp::ATan { scale } => format!("ATAN(scale={})", scale),
LookupOp::Tanh { scale } => format!("TANH(scale={})", scale),

View File

@@ -656,7 +656,7 @@ impl Model {
let mut symbol_values = SymbolValues::default();
for (symbol, value) in run_args.variables.iter() {
let symbol = model.symbol_table.sym(symbol);
let symbol = model.symbols.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
debug!("set {} to {}", symbol, value);
}
@@ -1199,9 +1199,9 @@ impl Model {
// Then number of columns in the circuits
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
trace!("input indices: {:?}", node.inputs());
trace!("output scales: {:?}", node.out_scales());
trace!(
"input scales: {:?}",
node.inputs()
.iter()
@@ -1220,8 +1220,8 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!("output dims: {:?}", node.out_dims());
debug!(
trace!("output dims: {:?}", node.out_dims());
trace!(
"input dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);

View File

@@ -279,6 +279,8 @@ pub fn new_op_from_onnx(
symbol_values: &SymbolValues,
run_args: &crate::RunArgs,
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use std::f64::consts::E;
use tract_onnx::tract_core::ops::array::Trilu;
use crate::circuit::InputType;
@@ -855,6 +857,7 @@ pub fn new_op_from_onnx(
}
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[0]).into(),
base: E.into(),
}),
"Ln" => {
if run_args.bounded_log_lookup {
@@ -1004,21 +1007,21 @@ pub fn new_op_from_onnx(
op
}
"Iff" => SupportedOp::Linear(PolyOp::Iff),
"Less" => {
"<" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Less)
} else {
return Err(GraphError::InvalidDims(idx, "less".to_string()));
}
}
"LessEqual" => {
"<=" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::LessEqual)
} else {
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
}
}
"Greater" => {
">" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Greater)
@@ -1026,7 +1029,7 @@ pub fn new_op_from_onnx(
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
}
}
"GreaterEqual" => {
">=" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::GreaterEqual)
@@ -1134,7 +1137,57 @@ pub fn new_op_from_onnx(
})
}
} else {
unimplemented!("only support constant pow for now")
if let Some(c) = inputs[0].opkind().get_mutable_constant() {
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar base")
}
let base = c.raw_values[0];
SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[1]).into(),
base: base.into(),
})
} else {
unimplemented!("only support constant base or pow for now")
}
}
}
"Div" => {
let const_idx = inputs
.iter()
.enumerate()
.filter(|(_, n)| n.is_constant())
.map(|(i, _)| i)
.collect::<Vec<_>>();
if const_idx.len() > 1 {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}
let const_idx = const_idx[0];
if const_idx != 1 {
unimplemented!("only support div with constant as second input")
}
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] != 0. {
inputs[const_idx].decrement_use();
deleted_indices.push(const_idx);
// get the non constant index
let denom = c.raw_values[0];
SupportedOp::Hybrid(HybridOp::Div {
denom: denom.into(),
})
} else {
unimplemented!("only support non zero divisors of size 1")
}
} else {
unimplemented!("only support div with constant as second input")
}
}
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
@@ -1197,7 +1250,7 @@ pub fn new_op_from_onnx(
"And" => SupportedOp::Linear(PolyOp::And),
"Or" => SupportedOp::Linear(PolyOp::Or),
"Xor" => SupportedOp::Linear(PolyOp::Xor),
"Equals" => SupportedOp::Hybrid(HybridOp::Equals),
"==" => SupportedOp::Hybrid(HybridOp::Equals),
"Deconv" => {
let deconv_node: &Deconv = match node.op().downcast_ref::<Deconv>() {
Some(b) => b,

View File

@@ -1109,6 +1109,13 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
pub fn expand(&self, shape: &[usize]) -> Result<Self, TensorError> {
// if both have length 1 then we can just return the tensor
if self.dims().iter().product::<usize>() == 1 && shape.iter().product::<usize>() == 1 {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
}
if self.dims().len() > shape.len() {
return Err(TensorError::DimError(format!(
"Cannot expand {:?} to the smaller shape {:?}",

View File

@@ -1050,6 +1050,7 @@ pub fn scatter_nd<T: TensorType + Send + Sync>(
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index_val = index.get_slice(&slice)?;
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let src_val = src.get_slice(&slice)?;
output.set_slice(&index_slice, &src_val)?;
Ok::<_, TensorError>(())
@@ -1664,7 +1665,7 @@ pub mod nonlinearities {
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = exp(&x, 1.0);
/// let result = exp(&x, 1.0, std::f64::consts::E);
/// let expected = Tensor::<IntegerRep>::new(Some(&[7, 3269017, 7, 3, 3, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
@@ -1673,16 +1674,16 @@ pub mod nonlinearities {
/// Some(&[37, 12, 41]),
/// &[3],
/// ).unwrap();
/// let result = exp(&x, 512.0);
/// let result = exp(&x, 512.0, std::f64::consts::E);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[550, 524, 555]), &[3]).unwrap();
///
/// assert_eq!(result, expected);
/// ```
pub fn exp(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
pub fn exp(a: &Tensor<IntegerRep>, scale_input: f64, base: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;
let fout = scale_input * kix.exp();
let fout = scale_input * base.powf(kix);
let rounded = fout.round();
Ok::<_, TensorError>(rounded as IntegerRep)
})

View File

@@ -205,7 +205,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 96] = [
const TESTS: [&str; 98] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -306,6 +306,8 @@ mod native_tests {
"lenet_5", // 93
"rsqrt", // 94
"log", // 95
"exp", // 96
"general_exp", // 97
];
const WASM_TESTS: [&str; 46] = [
@@ -490,7 +492,7 @@ mod native_tests {
#[cfg(feature="icicle")]
seq!(N in 0..=2 {
#(#[test_case(TESTS_AGGR[N])])*
fn aggr_prove_and_verify_(test: &str) {
fn kzg_aggr_prove_and_verify_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(test_dir.path().to_str().unwrap(), test);
@@ -544,7 +546,7 @@ mod native_tests {
}
});
seq!(N in 0..=95 {
seq!(N in 0..=97 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -634,7 +636,7 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
// currently variable output rank is not supported in ONNX
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" {
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" && test != "scatter_nd" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);