mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-14 08:48:01 -05:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9b0eed1b1 | ||
|
|
0610e05a86 | ||
|
|
512eb5ae5d | ||
|
|
9ba6c37e2c |
51
Cargo.lock
generated
51
Cargo.lock
generated
@@ -598,12 +598,6 @@ version = "1.0.81"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
|
||||
|
||||
[[package]]
|
||||
name = "anymap"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33954243bd79057c2de7338850b85983a44588021f8a5fee574a8888c6de4344"
|
||||
|
||||
[[package]]
|
||||
name = "anymap2"
|
||||
version = "0.13.0"
|
||||
@@ -3328,16 +3322,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "ndarray"
|
||||
version = "0.16.1"
|
||||
version = "0.15.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
|
||||
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
||||
dependencies = [
|
||||
"matrixmultiply",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"portable-atomic",
|
||||
"portable-atomic-util",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
@@ -3865,15 +3857,6 @@ version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic-util"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507"
|
||||
dependencies = [
|
||||
"portable-atomic",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "poseidon"
|
||||
version = "0.2.0"
|
||||
@@ -5689,11 +5672,10 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-core"
|
||||
version = "0.21.8-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"anymap",
|
||||
"bit-set",
|
||||
"derive-new",
|
||||
"downcast-rs",
|
||||
@@ -5714,8 +5696,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-data"
|
||||
version = "0.21.8-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"downcast-rs",
|
||||
@@ -5729,7 +5711,6 @@ dependencies = [
|
||||
"nom",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"parking_lot",
|
||||
"scan_fmt",
|
||||
"smallvec",
|
||||
"string-interner",
|
||||
@@ -5737,8 +5718,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-hir"
|
||||
version = "0.21.8-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"log",
|
||||
@@ -5747,8 +5728,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-linalg"
|
||||
version = "0.21.8-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"cc",
|
||||
@@ -5774,8 +5755,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-nnef"
|
||||
version = "0.21.8-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"flate2",
|
||||
@@ -5788,8 +5769,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-onnx"
|
||||
version = "0.21.8-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"derive-new",
|
||||
@@ -5805,8 +5786,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-onnx-opl"
|
||||
version = "0.21.8-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=37132e0397d0a73e5bd3a8615d932dabe44f6736#37132e0397d0a73e5bd3a8615d932dabe44f6736"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"log",
|
||||
|
||||
@@ -89,7 +89,7 @@ pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch =
|
||||
"tokio-runtime",
|
||||
], default-features = false, optional = true }
|
||||
pyo3-log = { version = "0.10.0", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
|
||||
objc = { version = "0.2.4", optional = true }
|
||||
|
||||
@@ -1,52 +1,75 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import sys
|
||||
import json
|
||||
import numpy as np
|
||||
import tf2onnx
|
||||
|
||||
sys.path.append("..")
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Just one Linear layer
|
||||
"""
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
# Use this line if you want to visualize the weights
|
||||
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
|
||||
self.channels = configs.enc_in
|
||||
self.individual = configs.individual
|
||||
if self.individual:
|
||||
self.Linear = nn.ModuleList()
|
||||
for i in range(self.channels):
|
||||
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
|
||||
else:
|
||||
self.Linear = nn.Linear(self.seq_len, self.pred_len)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [Batch, Input length, Channel]
|
||||
if self.individual:
|
||||
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
|
||||
for i in range(self.channels):
|
||||
output[:,:,i] = self.Linear[i](x[:,:,i])
|
||||
x = output
|
||||
else:
|
||||
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
|
||||
return x # [Batch, Output length, Channel]
|
||||
|
||||
class Configs:
|
||||
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
|
||||
self.seq_len = seq_len
|
||||
self.pred_len = pred_len
|
||||
self.enc_in = enc_in
|
||||
self.individual = individual
|
||||
|
||||
model = 'Linear'
|
||||
seq_len = 10
|
||||
pred_len = 4
|
||||
enc_in = 3
|
||||
|
||||
configs = Configs(seq_len, pred_len, enc_in, True)
|
||||
circuit = Model(configs)
|
||||
|
||||
x = torch.randn(1, seq_len, pred_len)
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.layers import *
|
||||
from tensorflow.keras.models import Model
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=15, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
# the model's input names
|
||||
input_names=['input'],
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
# gather_nd in tf then export to onnx
|
||||
x = in1 = Input((4, 1), dtype=tf.int32)
|
||||
w = in2 = Input((4, ), dtype=tf.int32)
|
||||
|
||||
class MyLayer(Layer):
|
||||
def call(self, x, w):
|
||||
shape = tf.constant([8])
|
||||
return tf.scatter_nd(x, w, shape)
|
||||
|
||||
x = MyLayer()(x, w)
|
||||
|
||||
|
||||
|
||||
tm = Model((in1, in2), x)
|
||||
tm.summary()
|
||||
tm.compile(optimizer='adam', loss='mse')
|
||||
|
||||
shape = [1, 4, 1]
|
||||
index_shape = [1, 4]
|
||||
# After training, export to onnx (network.onnx) and create a data file (input.json)
|
||||
x = np.random.randint(0, 4, shape)
|
||||
# w = random int tensor
|
||||
w = np.random.randint(0, 4, index_shape)
|
||||
|
||||
spec = tf.TensorSpec(shape, tf.int32, name='input_0')
|
||||
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
|
||||
|
||||
model_path = "network.onnx"
|
||||
|
||||
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
|
||||
|
||||
|
||||
d = x.reshape([-1]).tolist()
|
||||
d1 = w.reshape([-1]).tolist()
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
|
||||
data = dict(
|
||||
input_data=[d, d1],
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
|
||||
@@ -1,16 +1 @@
|
||||
{
|
||||
"input_data": [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
1,
|
||||
0,
|
||||
2,
|
||||
1
|
||||
]
|
||||
]
|
||||
}
|
||||
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
|
||||
Binary file not shown.
@@ -1687,7 +1687,6 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::
|
||||
Ok(output.into())
|
||||
}
|
||||
|
||||
// assumes unique values in fullset
|
||||
pub(crate) fn get_missing_set_elements<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
@@ -1700,8 +1699,6 @@ pub(crate) fn get_missing_set_elements<
|
||||
let set_len = fullset.len();
|
||||
input.flatten();
|
||||
|
||||
// while fullset is less than len of input concat
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
|
||||
@@ -656,7 +656,7 @@ impl Model {
|
||||
|
||||
let mut symbol_values = SymbolValues::default();
|
||||
for (symbol, value) in run_args.variables.iter() {
|
||||
let symbol = model.symbols.sym(symbol);
|
||||
let symbol = model.symbol_table.sym(symbol);
|
||||
symbol_values = symbol_values.with(&symbol, *value as i64);
|
||||
debug!("set {} to {}", symbol, value);
|
||||
}
|
||||
@@ -1199,9 +1199,9 @@ impl Model {
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
region.debug_report();
|
||||
trace!("input indices: {:?}", node.inputs());
|
||||
trace!("output scales: {:?}", node.out_scales());
|
||||
trace!(
|
||||
debug!("input indices: {:?}", node.inputs());
|
||||
debug!("output scales: {:?}", node.out_scales());
|
||||
debug!(
|
||||
"input scales: {:?}",
|
||||
node.inputs()
|
||||
.iter()
|
||||
@@ -1220,8 +1220,8 @@ impl Model {
|
||||
// we re-assign inputs, always from the 0 outlet
|
||||
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
|
||||
};
|
||||
trace!("output dims: {:?}", node.out_dims());
|
||||
trace!(
|
||||
debug!("output dims: {:?}", node.out_dims());
|
||||
debug!(
|
||||
"input dims {:?}",
|
||||
values.iter().map(|v| v.dims()).collect_vec()
|
||||
);
|
||||
|
||||
@@ -1007,21 +1007,21 @@ pub fn new_op_from_onnx(
|
||||
op
|
||||
}
|
||||
"Iff" => SupportedOp::Linear(PolyOp::Iff),
|
||||
"<" => {
|
||||
"Less" => {
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::Less)
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "less".to_string()));
|
||||
}
|
||||
}
|
||||
"<=" => {
|
||||
"LessEqual" => {
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::LessEqual)
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
|
||||
}
|
||||
}
|
||||
">" => {
|
||||
"Greater" => {
|
||||
// Extract the slope layer hyperparams
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::Greater)
|
||||
@@ -1029,7 +1029,7 @@ pub fn new_op_from_onnx(
|
||||
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
|
||||
}
|
||||
}
|
||||
">=" => {
|
||||
"GreaterEqual" => {
|
||||
// Extract the slope layer hyperparams
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::GreaterEqual)
|
||||
@@ -1250,7 +1250,7 @@ pub fn new_op_from_onnx(
|
||||
"And" => SupportedOp::Linear(PolyOp::And),
|
||||
"Or" => SupportedOp::Linear(PolyOp::Or),
|
||||
"Xor" => SupportedOp::Linear(PolyOp::Xor),
|
||||
"==" => SupportedOp::Hybrid(HybridOp::Equals),
|
||||
"Equals" => SupportedOp::Hybrid(HybridOp::Equals),
|
||||
"Deconv" => {
|
||||
let deconv_node: &Deconv = match node.op().downcast_ref::<Deconv>() {
|
||||
Some(b) => b,
|
||||
|
||||
@@ -1109,13 +1109,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
pub fn expand(&self, shape: &[usize]) -> Result<Self, TensorError> {
|
||||
// if both have length 1 then we can just return the tensor
|
||||
if self.dims().iter().product::<usize>() == 1 && shape.iter().product::<usize>() == 1 {
|
||||
let mut output = self.clone();
|
||||
output.reshape(shape)?;
|
||||
return Ok(output);
|
||||
}
|
||||
|
||||
if self.dims().len() > shape.len() {
|
||||
return Err(TensorError::DimError(format!(
|
||||
"Cannot expand {:?} to the smaller shape {:?}",
|
||||
|
||||
@@ -1050,7 +1050,6 @@ pub fn scatter_nd<T: TensorType + Send + Sync>(
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let index_val = index.get_slice(&slice)?;
|
||||
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
|
||||
let src_val = src.get_slice(&slice)?;
|
||||
output.set_slice(&index_slice, &src_val)?;
|
||||
Ok::<_, TensorError>(())
|
||||
|
||||
@@ -636,7 +636,7 @@ mod native_tests {
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_large_batch_public_outputs_(test: &str) {
|
||||
// currently variable output rank is not supported in ONNX
|
||||
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" && test != "scatter_nd" {
|
||||
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
|
||||
Reference in New Issue
Block a user