Compare commits

...

2 Commits

Author SHA1 Message Date
jmjac
de9e3f2673 Add __version__ to python bindings (#739) 2024-03-13 14:22:20 +00:00
dante
a1450f8df7 feat: gather_nd/scatter_nd support (#737) 2024-03-11 22:05:40 +00:00
17 changed files with 946 additions and 16 deletions

View File

@@ -0,0 +1,48 @@
from torch import nn
import json
import numpy as np
import tf2onnx
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
# Build a tiny Keras model whose only op is a batched gather_nd, export it to
# ONNX (network.onnx) and write a matching sample input file (input.json).
x = in1 = Input((15, 18,))
w = in2 = Input((15, 1), dtype=tf.int32)
x = tf.gather_nd(x, w, batch_dims=1)
tm = Model((in1, in2), x)
tm.summary()
tm.compile(optimizer='adam', loss='mse')

shape = [1, 15, 18]
index_shape = [1, 15, 1]

# Sample data tensor. The original call passed an extra leading 1 to `rand`,
# which only added a redundant axis; the flattened values are the same.
x = 0.1 * np.random.rand(*shape)
# Random integer index tensor.
w = np.random.randint(0, 10, index_shape)

spec = tf.TensorSpec(shape, tf.float32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(
    tm,
    input_signature=[spec, index_spec],
    inputs_as_nchw=['input_0', 'input_1'],
    opset=12,
    output_path=model_path,
)

d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
data = dict(
    input_data=[d, d1],
)

# Serialize data into file; the context manager makes sure the handle is
# flushed and closed.
with open("input.json", 'w') as f:
    json.dump(data, f)

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import sys
import json
sys.path.append("..")
class Model(nn.Module):
    """
    Just one Linear layer.

    Maps the time axis of the input from ``configs.seq_len`` steps to
    ``configs.pred_len`` steps. When ``configs.individual`` is truthy a
    separate ``nn.Linear`` is created for each of the ``configs.enc_in``
    channels; otherwise one ``nn.Linear`` is shared by all channels.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        # Use this line if you want to visualize the weights
        # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
        self.channels = configs.enc_in
        self.individual = configs.individual
        if self.individual:
            # one independent projection per channel
            self.Linear = nn.ModuleList()
            for i in range(self.channels):
                self.Linear.append(nn.Linear(self.seq_len, self.pred_len))
        else:
            self.Linear = nn.Linear(self.seq_len, self.pred_len)

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        if self.individual:
            output = torch.zeros([x.size(0), self.pred_len, x.size(2)], dtype=x.dtype).to(x.device)
            # only the first `self.channels` channels are written; any further
            # channel present in `x` keeps its zero initialisation
            for i in range(self.channels):
                output[:, :, i] = self.Linear[i](x[:, :, i])
            x = output
        else:
            # move the time axis last so nn.Linear acts on it, then move it back
            x = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x  # [Batch, Output length, Channel]
class Configs:
    """Plain container for the hyper-parameters read by ``Model``."""

    def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
        # length of the input window
        self.seq_len = seq_len
        # length of the predicted window
        self.pred_len = pred_len
        # number of channels
        self.enc_in = enc_in
        # whether each channel gets its own linear layer
        self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3

configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)

# NOTE(review): the channel axis of the sample input is sized with `pred_len`
# (4) while the model is configured with `enc_in` (3) channels. Presumably
# `enc_in` was intended; left unchanged so the regenerated artifacts keep the
# same input size. Confirm before changing.
x = torch.randn(1, seq_len, pred_len)

torch.onnx.export(circuit, x, "network.onnx",
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=15,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  # the model's input names
                  input_names=['input'],
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
                                'output': {0: 'batch_size'}})

d1 = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(
    input_data=[d1],
)

# Serialize data into file; the context manager makes sure the handle is
# flushed and closed.
with open("input.json", 'w') as f:
    json.dump(data, f)

View File

@@ -0,0 +1 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}

Binary file not shown.

View File

@@ -972,6 +972,55 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
Ok((output, linear_index))
}
/// Gather-ND accumulated layout.
///
/// Gathers slices of the data tensor `values[0]` addressed by the index tuples stored along
/// the last axis of the index tensor `values[1]`. The first `batch_dims` axes of both tensors
/// are treated as batch axes.
///
/// # Arguments
/// * `config` - base configuration forwarded to the underlying layouts
/// * `region` - region context forwarded to the underlying layouts
/// * `values` - `[data, indices]`
/// * `batch_dims` - number of leading batch axes
///
/// # Returns
/// `(output, linear_index)`: the gathered tensor, reshaped to the gather-nd output shape, and
/// the linearized index produced by `linearize_nd_index` that was used to select it.
///
/// # Errors
/// Returns `TensorError::DimMismatch` when the index tensor has no dimensions, when
/// `batch_dims` exceeds the rank of the data tensor, or when the index tuple length exceeds
/// the number of non-batch data axes. Errors from the underlying layouts are propagated.
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 2],
    batch_dims: usize,
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
    let (input, index) = (values[0].clone(), values[1].clone());
    let index_dims = index.dims().to_vec();
    let input_dims = input.dims().to_vec();

    // length of each index tuple
    let last_value = *index_dims
        .last()
        .ok_or_else(|| TensorError::DimMismatch("gather_nd".to_string()))?;

    // `checked_sub` turns `batch_dims > rank` into an error instead of a usize underflow
    let non_batch_rank = input_dims
        .len()
        .checked_sub(batch_dims)
        .ok_or_else(|| TensorError::DimMismatch("gather_nd".to_string()))?;
    if last_value > non_batch_rank {
        return Err(TensorError::DimMismatch("gather_nd".to_string()).into());
    }

    // Output shape = index shape without its last axis, followed by the data axes that the
    // index tuples do not address (everything after `batch_dims + last_value`).
    // When the tuple addresses every non-batch axis the gathered items are scalars; otherwise
    // each gathered item is a slice over the trailing data axes.
    let output_size = {
        let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
        let mut dims = index_dims[..index_dims.len() - 1].to_vec();
        dims.extend_from_slice(&input_dims[batch_dims + last_value..]);
        assert_eq!(output_rank, dims.len());
        dims
    };

    let linear_index = linearize_nd_index(config, region, &[index], input.dims(), batch_dims)?;

    let mut output = select(config, region, &[input, linear_index.clone()])?;
    output.reshape(&output_size)?;

    Ok((output, linear_index))
}
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// For instance if the dims is [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
@@ -1059,6 +1108,171 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
Ok(output.into())
}
/// Takes a tensor representing a nd index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and batch_dims integer b, this operator gathers slices of data into an output tensor of rank q + r - indices_shape[-1] - 1 - b.
/// indices is an q-dimensional integer tensor, best thought of as a (q-1)-dimensional tensor of index-tuples into data, where each element defines a slice of data
/// batch_dims (denoted as b) is an integer indicating the number of batch dimensions, i.e the leading b number of dimensions of data tensor and indices are representing the batches, and the gather starts from the b+1 dimension.
/// Some salient points about the inputs rank and shape:
/// r >= 1 and q >= 1 are to be honored. There is no dependency condition to be met between ranks r and q
/// The first b dimensions of the shape of indices tensor and data tensor must be equal.
/// b < min(q, r) is to be honored.
/// The indices_shape[-1] should have a value between 1 (inclusive) and rank r-b (inclusive)
/// All values in indices are expected to be within bounds [-s, s-1] along axis of size s (i.e.) -data_shape[i] <= indices[...,i] <= data_shape[i] - 1. It is an error if any of the index values are out of bounds.
/// The output is computed as follows:
/// The output tensor is obtained by mapping each index-tuple in the indices tensor to the corresponding slice of the input data.
/// If indices_shape[-1] > r-b => error condition
/// If indices_shape[-1] == r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b, where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
/// Let us think of each such r-b ranked tensor as indices_slice. Each scalar value corresponding to data[0:b-1,indices_slice] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Example 1 below)
/// If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b. Let us think of each such tensors as indices_slice. Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Examples 2, 3, 4 and 5 below)
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
dims: &[usize],
batch_dims: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let index = values[0].clone();
let index_dims = index.dims().to_vec();
let last_dim = index.dims().last().unwrap();
let input_rank = dims[batch_dims..].len();
let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;
let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
let mut res = 1;
for dim in dims.iter().skip(i + 1) {
res *= dim;
}
Ok::<_, region::RegionError>(F::from(res as u64))
})?;
let iteration_dims = index.dims()[0..batch_dims].to_vec();
let mut batch_cartesian_coord = iteration_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if batch_cartesian_coord.is_empty() {
batch_cartesian_coord.push(vec![]);
}
let index_dim_multiplier: ValTensor<F> = dim_multiplier
.get_slice(&[batch_dims..dims.len()])?
.map(|x| ValType::Constant(x))
.into();
let mut outer_results = vec![];
for coord in batch_cartesian_coord {
let slice: Vec<Range<usize>> = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut index_slice = index.get_slice(&slice)?;
index_slice.reshape(&index_dims[batch_dims..])?;
// expand the index to the full dims by iterating over the rest of the dims and inserting constants
// eg in the case
// batch_dims = 0
// data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2]
// indices = [[0,1],[1,0]] # indices_shape = [2, 2]
// output = [[2,3],[4,5]] # output_shape = [2, 2]
// the index should be expanded to the shape [2,2,3]: [[0,1,0],[0,1,1],[1,0,0],[1,0,1]]
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if inner_cartesian_coord.is_empty() {
inner_cartesian_coord.push(vec![]);
}
let indices = if last_dim < &input_rank {
inner_cartesian_coord
.iter()
.map(|x| {
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index = index_slice.get_slice(&slice)?;
// map over cartesian coord of rest of dims and insert constants
let grid = (*last_dim..input_rank)
.map(|x| 0..dims[x])
.multi_cartesian_product();
Ok(grid
.map(|x| {
let index = index.clone();
let constant_valtensor: ValTensor<F> = Tensor::from(
x.into_iter().map(|x| ValType::Constant(F::from(x as u64))),
)
.into();
index.concat(constant_valtensor)
})
.collect::<Result<Vec<_>, TensorError>>()?)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>()
} else {
inner_cartesian_coord
.iter()
.map(|x| {
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
Ok(index_slice.get_slice(&slice)?)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
};
let mut const_offset = F::ZERO;
for i in 0..batch_dims {
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
}
let const_offset = create_constant_tensor(const_offset, 1);
let mut results = vec![];
for index_val in indices {
let mut index_val = index_val.clone();
index_val.flatten();
let res = pairwise(
config,
region,
&[index_val.clone(), index_dim_multiplier.clone()],
BaseOp::Mult,
)?;
let res = res.concat(const_offset.clone())?;
let res = sum(config, region, &[res])?;
results.push(res.get_inner_tensor()?.clone());
// assert than res is less than the product of the dims
assert!(
res.get_int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as i128),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
}
let result_tensor = Tensor::from(results.into_iter());
outer_results.push(result_tensor.combine()?);
}
let output = Tensor::from(outer_results.into_iter());
let output = output.combine()?;
Ok(output.into())
}
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
@@ -1133,8 +1347,6 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
assert_eq!(input.dims().len(), index.dims().len());
let input_dims = input.dims();
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
region.increment(index.len());
@@ -1201,7 +1413,86 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
&[full_index_set, input.clone()],
)?;
claimed_output.reshape(input_dims)?;
claimed_output.reshape(input.dims())?;
Ok(claimed_output)
}
/// Scatter Nd
///
/// Lays out a scatter-nd over `values = [input, index, src]`: the result is witnessed as a
/// "claimed output" and then tied back to `src` and `input` through gather and lookup
/// arguments rather than being computed in-circuit directly.
///
/// Returns the claimed output reshaped to the dimensions of `input`.
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());
// make sure the index lives in the circuit before it is used in any constraint
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
region.increment(index.len());
}
// the result can only be evaluated when none of the three tensors contains unknowns
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
let claimed_output: ValTensor<F> = if is_assigned {
// evaluate the scatter on the integer representations of the operands
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
let res = tensor::ops::scatter_nd(&input_inner, &index_inner, &src_inner)?;
res.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
// placeholder of unknown values with as many elements as `input`
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
// assign the claimed output
let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());
claimed_output.reshape(input.dims())?;
// scatter nd is the inverse of gather nd: gathering the claimed output at `index`
// (no batch dims) must give back `src`
let (gather_src, linear_index) =
gather_nd(config, region, &[claimed_output.clone(), index.clone()], 0)?;
// assert this is equal to the src
enforce_equality(config, region, &[gather_src, src])?;
// constants 0..input.len(), i.e. every flat position of the input
let full_index_set: ValTensor<F> =
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
// NOTE(review): presumably the flat positions that were NOT written by the scatter
// (elements of `full_index_set` absent from `linear_index`) - confirm against
// `get_missing_set_elements`
let input_indices = get_missing_set_elements(
config,
region,
&[linear_index, full_index_set.clone()],
true,
)?;
// now that it is flattened we can gather over elements on dim 0
claimed_output.flatten();
let (gather_input, _) = gather_elements(
config,
region,
&[claimed_output.clone(), input_indices.clone()],
0,
)?;
// assert this is a subset of the input
// NOTE(review): presumably constrains each (position, value) pair gathered above to
// appear in the (position, input value) table - confirm against `dynamic_lookup`
dynamic_lookup(
config,
region,
&[input_indices, gather_input],
&[full_index_set, input.clone()],
)?;
// restore the shape of the input before returning
claimed_output.reshape(input.dims())?;
Ok(claimed_output)
}

View File

@@ -14,10 +14,17 @@ pub enum PolyOp {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
GatherND {
batch_dims: usize,
indices: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterND {
constant_idx: Option<Tensor<usize>>,
},
MultiBroadcastTo {
shape: Vec<usize>,
},
@@ -89,7 +96,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -213,6 +222,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
@@ -229,6 +250,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
}?;
Ok(ForwardResult { output: res })
@@ -279,6 +315,16 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
}
}
PolyOp::GatherND {
batch_dims,
indices,
} => {
if let Some(idx) = indices {
tensor::ops::gather_nd(values[0].get_inner_tensor()?, idx, *batch_dims)?.into()
} else {
layouts::gather_nd(config, region, values[..].try_into()?, *batch_dims)?.0
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter(
@@ -292,6 +338,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::ScatterND { constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter_nd(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
)?
.into()
} else {
layouts::scatter_nd(config, region, values[..].try_into()?)?
}
}
PolyOp::DeConv {
padding,
output_padding,
@@ -389,7 +447,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else if matches!(self, PolyOp::ScatterElements { .. }) {
} else if matches!(self, PolyOp::ScatterElements { .. })
| matches!(self, PolyOp::ScatterND { .. })
{
vec![0, 2]
} else {
vec![]

View File

@@ -23,7 +23,10 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::ops::{
array::{Gather, GatherElements, MultiBroadcastTo, OneHot, ScatterElements, Slice, Topk},
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
},
change_axes::AxisOp,
cnn::{Conv, Deconv},
einsum::EinSum,
@@ -467,6 +470,78 @@ pub fn new_op_from_onnx(
// Extract the max value
}
"ScatterNd" => {
    // ScatterND takes exactly three inputs: (data, indices, updates)
    if inputs.len() != 3 {
        return Err(Box::new(GraphError::InvalidDims(
            idx,
            "scatter nd".to_string(),
        )));
    };
    // just verify it deserializes correctly
    let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;

    let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
        constant_idx: None,
    });

    // if param_visibility.is_public() {
    if let Some(c) = inputs[1].opkind().get_mutable_constant() {
        inputs[1].decrement_use();
        // The constant folded into the op is the indices tensor, which sits at position 1.
        // `inputs.len() - 1` would be 2 here and remove the updates tensor instead, while
        // the op then expects the updates as its second remaining input.
        deleted_indices.push(1);
        op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
            constant_idx: Some(c.raw_values.map(|x| x as usize)),
        })
    }
    // }

    // indices supplied as a graph input are integers: force scale 0 and an integer datum type
    if inputs[1].opkind().is_input() {
        inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
            scale: 0,
            datum_type: InputType::TDim,
        }));
        inputs[1].bump_scale(0);
    }

    op
}
"GatherNd" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather nd".to_string(),
)));
};
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
let batch_dims = op.batch_dims;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherElements" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(

View File

@@ -1103,6 +1103,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;

View File

@@ -673,6 +673,68 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&res), &dims)
}
/// Set a slice of the Tensor.
///
/// `indices` gives one range per leading axis; axes that are not listed are covered in
/// full. `value` is expanded to the shape of the addressed slice before being written.
/// An empty `indices` leaves the tensor untouched.
///
/// # Errors
/// Returns `TensorError::DimError` when more ranges are given than the tensor has axes, or
/// when a range is reversed or reaches past the end of its axis. Errors from expanding
/// `value` are propagated.
///
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
/// assert_eq!(a, b);
/// ```
pub fn set_slice(
    &mut self,
    indices: &[Range<usize>],
    value: &Tensor<T>,
) -> Result<(), TensorError>
where
    T: Send + Sync,
{
    if indices.is_empty() {
        return Ok(());
    }
    if self.dims.len() < indices.len() {
        return Err(TensorError::DimError(format!(
            "The dimensionality of the slice {:?} is greater than the tensor's {:?}",
            indices, self.dims
        )));
    }
    // reject reversed or out-of-range ranges up front instead of failing while writing
    if indices
        .iter()
        .zip(self.dims.iter())
        .any(|(range, dim)| range.start > range.end || range.end > *dim)
    {
        return Err(TensorError::DimError(format!(
            "The slice {:?} is out of bounds for the tensor's dims {:?}",
            indices, self.dims
        )));
    }

    // if indices weren't specified we fill them in as required
    let full_indices: Vec<Range<usize>> = indices
        .iter()
        .cloned()
        .chain(self.dims[indices.len()..].iter().map(|dim| 0..*dim))
        .collect();

    let full_dims: Vec<usize> = full_indices.iter().map(|x| x.end - x.start).collect();

    // now broadcast the value to the full dims
    let value = value.expand(&full_dims)?;

    // the i-th coordinate of the slice (row-major) receives the i-th element of `value`
    for (i, coord) in full_indices
        .iter()
        .cloned()
        .multi_cartesian_product()
        .enumerate()
    {
        self.set(&coord, value[i].clone());
    }

    Ok(())
}
/// Get the array index from rows / columns indices.
///
/// ```

View File

@@ -2,7 +2,7 @@ use super::TensorError;
use crate::tensor::{Tensor, TensorType};
use itertools::Itertools;
use maybe_rayon::{
iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator,
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
prelude::IntoParallelRefIterator,
};
use std::collections::{HashMap, HashSet};
@@ -1328,6 +1328,316 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
Ok(output)
}
/// Gather ND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to gather
/// * `batch_dims` - Number of batch dimensions
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::gather_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3]),
/// &[2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 0, 1, 1]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 3]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3, 4, 5, 6, 7]),
/// &[2, 2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 1, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]),
/// &[2, 2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0]),
/// &[2, 2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
/// ```
///
/// # Errors
/// Returns `TensorError::DimMismatch` when `batch_dims` is not strictly smaller than the
/// rank of both tensors, when the leading `batch_dims` axes of `input` and `index` differ,
/// or when the last axis of `index` is longer than the number of non-batch axes of `input`.
/// Slicing and reshaping errors are propagated.
pub fn gather_nd<T: TensorType + Send + Sync>(
    input: &Tensor<T>,
    index: &Tensor<usize>,
    batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
    // Calculate the output tensor size
    let index_dims = index.dims().to_vec();
    let input_dims = input.dims().to_vec();

    // b < min(q, r) must hold; this also keeps the subtractions and slices below in range
    if batch_dims >= index_dims.len() || batch_dims >= input_dims.len() {
        return Err(TensorError::DimMismatch("gather_nd".to_string()));
    }
    // the batch axes of data and indices must agree
    if index_dims[..batch_dims] != input_dims[..batch_dims] {
        return Err(TensorError::DimMismatch("gather_nd".to_string()));
    }

    // length of each index tuple
    let last_value = *index_dims
        .last()
        .ok_or_else(|| TensorError::DimMismatch("gather_nd".to_string()))?;
    if last_value > input_dims.len() - batch_dims {
        return Err(TensorError::DimMismatch("gather_nd".to_string()));
    }

    // Output shape = index shape without its last axis, followed by the data axes that the
    // index tuples do not address. When the tuple addresses every non-batch axis the gathered
    // items are scalars; otherwise each gathered item is a slice over the trailing data axes.
    let output_size = {
        let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
        let mut dims = index_dims[..index_dims.len() - 1].to_vec();
        dims.extend_from_slice(&input_dims[batch_dims + last_value..]);
        assert_eq!(output_rank, dims.len());
        dims
    };

    // cartesian coord over batch dims
    let mut batch_cartesian_coord = input_dims[0..batch_dims]
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    // with no batch axes there is still exactly one (empty) batch coordinate to visit
    if batch_cartesian_coord.is_empty() {
        batch_cartesian_coord.push(vec![]);
    }

    let outputs = batch_cartesian_coord
        .par_iter()
        .map(|batch_coord| {
            let batch_slice = batch_coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
            // per-batch views with the batch axes dropped
            let mut index_slice = index.get_slice(&batch_slice)?;
            index_slice.reshape(&index.dims()[batch_dims..])?;
            let mut input_slice = input.get_slice(&batch_slice)?;
            input_slice.reshape(&input.dims()[batch_dims..])?;

            let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
                .iter()
                .map(|x| 0..*x)
                .multi_cartesian_product()
                .collect::<Vec<_>>();

            if inner_cartesian_coord.is_empty() {
                inner_cartesian_coord.push(vec![]);
            }

            let gathered = inner_cartesian_coord
                .iter()
                .map(|coord| {
                    // `index_slice` no longer has batch axes, so only the inner coordinate
                    // is used here; re-applying the batch coordinate would address the
                    // tuple axis (or an axis that does not exist).
                    let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
                    // turn the index tuple into unit ranges over the leading data axes
                    let index_tuple = index_slice
                        .get_slice(&slice)?
                        .iter()
                        .map(|x| *x..*x + 1)
                        .collect::<Vec<_>>();
                    input_slice.get_slice(&index_tuple)
                })
                .collect::<Result<Vec<_>, TensorError>>()?;

            gathered.into_iter().collect::<Tensor<_>>().combine()
        })
        .collect::<Result<Vec<_>, _>>()?;

    let mut outputs = outputs.into_iter().flatten().collect::<Tensor<_>>();

    outputs.reshape(&output_size)?;

    Ok(outputs)
}
/// Scatter ND.
/// This operator is the inverse of GatherND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to scatter
/// * `src` - Tensor of src
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scatter_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[8],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[4, 3, 1, 7]),
/// &[4, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10, 11, 12]),
/// &[4],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 2]),
/// &[2, 1],
/// ).unwrap();
///
/// let src = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// ]),
/// &[2, 4, 4],
/// ).unwrap();
///
/// let result = scatter_nd(&x, &index, &src).unwrap();
///
/// let expected = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[2, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10]),
/// &[2],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[1, 1, 2],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9]),
/// &[1, 1],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ```
///
/// # Errors
/// Returns `TensorError::DimMismatch` when `index` has no dimensions or when its last axis
/// is longer than the rank of `input`. Slicing errors are propagated.
pub fn scatter_nd<T: TensorType + Send + Sync>(
    input: &Tensor<T>,
    index: &Tensor<usize>,
    src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
    // Calculate the output tensor size
    let index_dims = index.dims().to_vec();
    let input_dims = input.dims().to_vec();
    // length of each index tuple
    let last_value = *index_dims
        .last()
        .ok_or_else(|| TensorError::DimMismatch("scatter_nd".to_string()))?;
    if last_value > input_dims.len() {
        return Err(TensorError::DimMismatch("scatter_nd".to_string()));
    }

    let mut output = input.clone();

    // one coordinate per index tuple, i.e. over every axis of `index` except the last
    let mut cartesian_coord = index_dims[0..index_dims.len() - 1]
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    // a rank-1 index is a single tuple: visit it through the empty coordinate, as
    // `gather_nd` does
    if cartesian_coord.is_empty() {
        cartesian_coord.push(vec![]);
    }

    for coord in &cartesian_coord {
        let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
        // the index tuple selects the destination, the same coordinate selects the update
        let index_val = index.get_slice(&slice)?;
        let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
        let src_val = src.get_slice(&slice)?;
        output.set_slice(&index_slice, &src_val)?;
    }

    Ok(output)
}
fn axes_op<T: TensorType + Send + Sync>(
a: &Tensor<T>,
axes: &[usize],

View File

@@ -193,7 +193,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 77] = [
const TESTS: [&str; 79] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -275,6 +275,8 @@ mod native_tests {
"ltsf",
"remainder", //75
"bitshift",
"gather_nd",
"scatter_nd",
];
const WASM_TESTS: [&str; 46] = [
@@ -502,7 +504,7 @@ mod native_tests {
}
});
seq!(N in 0..=76 {
seq!(N in 0..=78 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -589,13 +591,16 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
}
}
#(#[test_case(TESTS[N])])*

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.