Mirror of https://github.com/zkonduit/ezkl.git

Compare commits (4 commits): de9e3f2673, a1450f8df7, ea535e2ecd, f8aa91ed08
48 examples/onnx/gather_nd/gen.py (Normal file)
@@ -0,0 +1,48 @@
from torch import nn
import json
import numpy as np
import tf2onnx


import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model


# gather_nd in tf then export to onnx


x = in1 = Input((15, 18,))
w = in2 = Input((15, 1), dtype=tf.int32)
x = tf.gather_nd(x, w, batch_dims=1)
tm = Model((in1, in2), x)
tm.summary()
tm.compile(optimizer='adam', loss='mse')

shape = [1, 15, 18]
index_shape = [1, 15, 1]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = 0.1 * np.random.rand(*shape)  # random data matching the declared input shape
# w = random int tensor
w = np.random.randint(0, 10, index_shape)

spec = tf.TensorSpec(shape, tf.float32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')

model_path = "network.onnx"

tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)

d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()

data = dict(
    input_data=[d, d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
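For a quick sanity check of what the exported graph computes, the following is a minimal numpy re-implementation of `tf.gather_nd` with `batch_dims=1` for exactly the shapes used above. It is illustrative only; the function name and the random data are assumptions, not part of the example.

```python
import numpy as np

def gather_nd_batch1(data, indices):
    # data: (B, 15, 18), indices: (B, 15, 1), batch_dims = 1.
    # For each batch b and row r, indices[b, r] is a 1-tuple that selects a
    # slice of data[b]: here a full row of length 18.
    return np.stack([
        np.stack([data[b][tuple(indices[b, r])] for r in range(indices.shape[1])])
        for b in range(data.shape[0])
    ])

data = 0.1 * np.random.rand(1, 15, 18)
idx = np.random.randint(0, 10, (1, 15, 1))
print(gather_nd_batch1(data, idx).shape)  # (1, 15, 18)
```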
1 examples/onnx/gather_nd/input.json (Normal file)
File diff suppressed because one or more lines are too long
BIN examples/onnx/gather_nd/network.onnx (Normal file)
Binary file not shown.
76 examples/onnx/scatter_nd/gen.py (Normal file)
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import sys
import json

sys.path.append("..")

class Model(nn.Module):
    """
    Just one Linear layer
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        # Use this line if you want to visualize the weights
        # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
        self.channels = configs.enc_in
        self.individual = configs.individual
        if self.individual:
            self.Linear = nn.ModuleList()
            for i in range(self.channels):
                self.Linear.append(nn.Linear(self.seq_len, self.pred_len))
        else:
            self.Linear = nn.Linear(self.seq_len, self.pred_len)

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        if self.individual:
            output = torch.zeros([x.size(0), self.pred_len, x.size(2)], dtype=x.dtype).to(x.device)
            for i in range(self.channels):
                output[:, :, i] = self.Linear[i](x[:, :, i])
            x = output
        else:
            x = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x  # [Batch, Output length, Channel]

class Configs:
    def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.enc_in = enc_in
        self.individual = individual

model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3

configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)

x = torch.randn(1, seq_len, pred_len)


torch.onnx.export(circuit, x, "network.onnx",
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=15,  # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  # the model's input names
                  input_names=['input'],
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
                                'output': {0: 'batch_size'}})

d1 = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(
    input_data=[d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
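Note that this model contains no explicit scatter op: with `individual=True`, the per-channel slice writes `output[:, :, i] = ...` in `Model.forward` are what end up as ScatterND nodes in the exported graph, and that is the op this example exercises. A minimal sketch of that pattern (shapes chosen only for illustration):

```python
import torch

# An in-place slice assignment like the one in Model.forward is what
# torch.onnx.export typically lowers to a ScatterND node (opset >= 11).
out = torch.zeros(1, 4, 3)
col = torch.arange(4, dtype=torch.float32).reshape(1, 4)
for i in range(3):
    out[:, :, i] = col + i  # slice write -> ScatterND in the exported graph
print(out.shape)  # torch.Size([1, 4, 3])
```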
1 examples/onnx/scatter_nd/input.json (Normal file)
@@ -0,0 +1 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
BIN examples/onnx/scatter_nd/network.onnx (Normal file)
Binary file not shown.
@@ -646,14 +646,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 2],
    dim_indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    let (mut input, index) = (values[0].clone(), values[1].clone());
    input.flatten();

    if !(dim_indices.all_prev_assigned() || region.is_dummy()) {
        return Err("dim_indices must be assigned".into());
    }
    // these will be assigned as constants
    let dim_indices: ValTensor<F> =
        Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();

    let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;

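For orientation, `select` constrains each output element to equal the input element at the claimed (flattened) index, roughly `output[i] = flat_input[index[i]]`. A plain-Python sketch of the intended semantics with made-up values (the in-circuit version enforces this with constraints rather than computing it directly):

```python
def select(flat_input, index):
    # output[i] = flat_input[index[i]]
    return [flat_input[i] for i in index]

print(select([10, 11, 12, 13], [2, 0, 3]))  # [12, 10, 13]
```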
@@ -929,91 +928,25 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 2],
|
||||
dim: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let (mut input, mut index_clone) = (values[0].clone(), values[1].clone());
|
||||
let (input, mut index_clone) = (values[0].clone(), values[1].clone());
|
||||
index_clone.flatten();
|
||||
if index_clone.is_singleton() {
|
||||
index_clone.reshape(&[1])?;
|
||||
}
|
||||
|
||||
let mut assigned_len = vec![];
|
||||
if !input.all_prev_assigned() {
|
||||
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
||||
assigned_len.push(input.len());
|
||||
}
|
||||
if !index_clone.all_prev_assigned() {
|
||||
index_clone = region.assign(&config.custom_gates.inputs[1], &index_clone)?;
|
||||
assigned_len.push(index_clone.len());
|
||||
}
|
||||
|
||||
if !assigned_len.is_empty() {
|
||||
// safe to unwrap since we've just checked it has at least one element
|
||||
region.increment(*assigned_len.iter().max().unwrap());
|
||||
}
|
||||
|
||||
// Calculate the output tensor size
|
||||
let input_dims = input.dims();
|
||||
let mut output_size = input_dims.to_vec();
|
||||
|
||||
output_size[dim] = index_clone.dims()[0];
|
||||
|
||||
// these will be assigned as constants
|
||||
let mut indices = Tensor::from((0..input.dims()[dim] as u64).map(|x| F::from(x)));
|
||||
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
||||
region.increment(indices.len());
|
||||
let linear_index =
|
||||
linearize_element_index(config, region, &[index_clone], input_dims, dim, true)?;
|
||||
|
||||
let mut iteration_dims = output_size.clone();
|
||||
iteration_dims[dim] = 1;
|
||||
let mut output = select(config, region, &[input, linear_index])?;
|
||||
|
||||
// Allocate memory for the output tensor
|
||||
let cartesian_coord = iteration_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut results = HashMap::new();
|
||||
|
||||
for coord in cartesian_coord {
|
||||
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
slice[dim] = 0..input_dims[dim];
|
||||
|
||||
let mut sliced_input = input.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
|
||||
let res = select(
|
||||
config,
|
||||
region,
|
||||
&[sliced_input, index_clone.clone()],
|
||||
indices.clone(),
|
||||
)?;
|
||||
|
||||
results.insert(coord, res);
|
||||
}
|
||||
|
||||
// Allocate memory for the output tensor
|
||||
let cartesian_coord = output_size
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let mut key = coord.clone();
|
||||
key[dim] = 0;
|
||||
let result = &results.get(&key).ok_or("missing result")?;
|
||||
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
|
||||
Ok::<ValType<F>, region::RegionError>(o)
|
||||
})?;
|
||||
|
||||
// Reshape the output tensor
|
||||
if index_clone.is_singleton() {
|
||||
output_size.remove(dim);
|
||||
}
|
||||
output.reshape(&output_size)?;
|
||||
|
||||
Ok(output.into())
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Gather accumulated layout
|
||||
@@ -1022,82 +955,387 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
dim: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let (mut input, mut index) = (values[0].clone(), values[1].clone());
|
||||
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
|
||||
let (input, index) = (values[0].clone(), values[1].clone());
|
||||
|
||||
assert_eq!(input.dims().len(), index.dims().len());
|
||||
|
||||
if !input.all_prev_assigned() {
|
||||
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
||||
}
|
||||
if !index.all_prev_assigned() {
|
||||
index = region.assign(&config.custom_gates.inputs[1], &index)?;
|
||||
}
|
||||
|
||||
region.increment(std::cmp::max(input.len(), index.len()));
|
||||
|
||||
// Calculate the output tensor size
|
||||
let input_dims = input.dims();
|
||||
let output_size = index.dims().to_vec();
|
||||
|
||||
// these will be assigned as constants
|
||||
let mut indices = Tensor::from((0..input_dims[dim] as u64).map(|x| F::from(x)));
|
||||
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
||||
region.increment(indices.len());
|
||||
let linear_index = linearize_element_index(config, region, &[index], input.dims(), dim, false)?;
|
||||
|
||||
let mut iteration_dims = output_size.clone();
|
||||
iteration_dims[dim] = 1;
|
||||
let mut output = select(config, region, &[input, linear_index.clone()])?;
|
||||
|
||||
output.reshape(&output_size)?;
|
||||
|
||||
Ok((output, linear_index))
|
||||
}
|
||||
|
||||
/// Gather accumulated layout
|
||||
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
batch_dims: usize,
|
||||
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
|
||||
let (input, index) = (values[0].clone(), values[1].clone());
|
||||
|
||||
let index_dims = index.dims().to_vec();
|
||||
let input_dims = input.dims().to_vec();
|
||||
let last_value = index_dims
|
||||
.last()
|
||||
.ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
|
||||
if last_value > &(input_dims.len() - batch_dims) {
|
||||
return Err(TensorError::DimMismatch("gather_nd".to_string()).into());
|
||||
}
|
||||
|
||||
let output_size =
|
||||
// If indices_shape[-1] == r-b, since the rank of indices is q,
|
||||
// indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
|
||||
// where N is an integer equal to the product of all the elements in the batch dimensions of indices_shape.
|
||||
// Let us think of each such r-b ranked tensor as indices_slice.
|
||||
// Each scalar value corresponding to data[0:b-1,indices_slice] is filled into
|
||||
// the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
|
||||
// if indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b.
|
||||
// Let us think of each such tensors as indices_slice.
|
||||
// Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
|
||||
{
|
||||
let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
|
||||
|
||||
let mut dims = index_dims[..index_dims.len() - 1].to_vec();
|
||||
let input_offset = batch_dims + last_value;
|
||||
dims.extend(input_dims[input_offset..input_dims.len()].to_vec());
|
||||
|
||||
assert_eq!(output_rank, dims.len());
|
||||
dims
|
||||
|
||||
};
|
||||
|
||||
let linear_index = linearize_nd_index(config, region, &[index], input.dims(), batch_dims)?;
|
||||
|
||||
let mut output = select(config, region, &[input, linear_index.clone()])?;
|
||||
|
||||
output.reshape(&output_size)?;
|
||||
|
||||
Ok((output, linear_index))
|
||||
}
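The output-shape rule used above follows the ONNX GatherND specification: with data of rank r, indices of rank q, m = indices_shape[-1] and b batch dimensions, the output has rank q + r - m - 1 - b. A small numpy reference for the batch_dims = 0 case only (the function name and data below are illustrative assumptions):

```python
import numpy as np

def gather_nd_ref(data, indices):
    # ONNX GatherND with batch_dims = 0: each length-m row of `indices`
    # picks data[tuple(row)], a scalar if m == data.ndim, a slice otherwise.
    m = indices.shape[-1]
    out_shape = list(indices.shape[:-1]) + list(data.shape[m:])
    rows = indices.reshape(-1, m)
    gathered = [data[tuple(row)] for row in rows]
    return np.array(gathered).reshape(out_shape)

data = np.arange(8).reshape(2, 2, 2)
idx = np.array([[0, 1], [1, 0]])   # m = 2: picks data[0, 1] and data[1, 0]
print(gather_nd_ref(data, idx))    # [[2 3] [4 5]], shape (2, 2)
```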
|
||||
|
||||
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
|
||||
/// The linearized index is the index of the element in the flattened tensor.
|
||||
/// For instance, if the dims are [3, 5, 2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13.
|
||||
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
dims: &[usize],
|
||||
dim: usize,
|
||||
is_flat_index: bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let index = values[0].clone();
|
||||
if !is_flat_index {
|
||||
assert_eq!(index.dims().len(), dims.len());
|
||||
// if the index is already flat, return it
|
||||
if index.dims().len() == 1 {
|
||||
return Ok(index);
|
||||
}
|
||||
}
|
||||
|
||||
let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;
|
||||
|
||||
let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
|
||||
let mut res = 1;
|
||||
for dim in dims.iter().skip(i + 1) {
|
||||
res *= dim;
|
||||
}
|
||||
|
||||
Ok::<_, region::RegionError>(F::from(res as u64))
|
||||
})?;
|
||||
|
||||
let iteration_dims = if is_flat_index {
|
||||
let mut dims = dims.to_vec();
|
||||
dims[dim] = index.len();
|
||||
dims
|
||||
} else {
|
||||
index.dims().to_vec()
|
||||
};
|
||||
|
||||
// Allocate memory for the output tensor
|
||||
let cartesian_coord = iteration_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut results = HashMap::new();
|
||||
let val_dim_multiplier: ValTensor<F> = dim_multiplier
|
||||
.get_slice(&[dim..dim + 1])?
|
||||
.map(|x| ValType::Constant(x))
|
||||
.into();
|
||||
|
||||
for coord in cartesian_coord {
|
||||
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
slice[dim] = 0..input_dims[dim];
|
||||
let mut output = Tensor::new(None, &[cartesian_coord.len()])?;
|
||||
|
||||
let mut sliced_input = input.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice: Vec<Range<usize>> = if is_flat_index {
|
||||
coord[dim..dim + 1].iter().map(|x| *x..*x + 1).collect()
|
||||
} else {
|
||||
coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
slice[dim] = 0..output_size[dim];
|
||||
let mut sliced_index = index.get_slice(&slice)?;
|
||||
sliced_index.flatten();
|
||||
let index_val = index.get_slice(&slice)?;
|
||||
|
||||
let res = select(
|
||||
let mut const_offset = F::ZERO;
|
||||
for i in 0..dims.len() {
|
||||
if i != dim {
|
||||
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
|
||||
}
|
||||
}
|
||||
let const_offset = create_constant_tensor(const_offset, 1);
|
||||
|
||||
let res = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[sliced_input, sliced_index],
|
||||
indices.clone(),
|
||||
&[index_val, val_dim_multiplier.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
results.insert(coord, res);
|
||||
}
|
||||
let res = pairwise(config, region, &[res, const_offset], BaseOp::Add)?;
|
||||
|
||||
// Allocate memory for the output tensor
|
||||
let cartesian_coord = output_size
|
||||
Ok(res.get_inner_tensor()?[0].clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut output, inner_loop_function)?;
|
||||
|
||||
Ok(output.into())
|
||||
}
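As the doc comment above describes, `linearize_element_index` converts per-dimension indices into flat offsets using row-major strides, where each stride is the product of the trailing dimensions. A plain-Python sketch of that arithmetic with made-up dims:

```python
def strides(dims):
    # stride[i] = product of dims[i+1:], i.e. row-major (C order) strides
    s = [1] * len(dims)
    for i in reversed(range(len(dims) - 1)):
        s[i] = s[i + 1] * dims[i + 1]
    return s

def linearize(coord, dims):
    return sum(c * s for c, s in zip(coord, strides(dims)))

dims = [3, 5, 2]
print(strides(dims))               # [10, 2, 1]
print(linearize([1, 2, 1], dims))  # 1*10 + 2*2 + 1*1 = 15
```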
|
||||
|
||||
/// Takes a tensor representing an nd index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// Given a data tensor of rank r >= 1, an indices tensor of rank q >= 1, and a batch_dims integer b, this operator gathers slices of data into an output tensor of rank q + r - indices_shape[-1] - 1 - b.
/// indices is a q-dimensional integer tensor, best thought of as a (q-1)-dimensional tensor of index-tuples into data, where each element defines a slice of data.
/// batch_dims (denoted as b) is an integer indicating the number of batch dimensions, i.e. the leading b dimensions of the data and indices tensors represent the batches, and the gather starts from dimension b+1.
/// Some salient points about the inputs' rank and shape:
/// r >= 1 and q >= 1 are to be honored. There is no dependency condition to be met between ranks r and q.
/// The first b dimensions of the shape of the indices tensor and the data tensor must be equal.
/// b < min(q, r) is to be honored.
/// indices_shape[-1] should have a value between 1 (inclusive) and rank r-b (inclusive).
/// All values in indices are expected to be within bounds [-s, s-1] along an axis of size s, i.e. -data_shape[i] <= indices[...,i] <= data_shape[i] - 1. It is an error if any of the index values are out of bounds.
/// The output is computed as follows:
/// The output tensor is obtained by mapping each index-tuple in the indices tensor to the corresponding slice of the input data.
/// If indices_shape[-1] > r-b => error condition.
/// If indices_shape[-1] == r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b, where N is an integer equal to the product of all the elements in the batch dimensions of indices_shape.
/// Let us think of each such r-b ranked tensor as indices_slice. Each scalar value corresponding to data[0:b-1,indices_slice] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Example 1 below).
/// If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension < r-b. Let us think of each such tensor as indices_slice. Each tensor slice corresponding to data[0:b-1, indices_slice, :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Examples 2, 3, 4 and 5 below).
|
||||
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
dims: &[usize],
|
||||
batch_dims: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let index = values[0].clone();
|
||||
let index_dims = index.dims().to_vec();
|
||||
|
||||
let last_dim = index.dims().last().unwrap();
|
||||
let input_rank = dims[batch_dims..].len();
|
||||
|
||||
let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;
|
||||
let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
|
||||
let mut res = 1;
|
||||
for dim in dims.iter().skip(i + 1) {
|
||||
res *= dim;
|
||||
}
|
||||
Ok::<_, region::RegionError>(F::from(res as u64))
|
||||
})?;
|
||||
|
||||
let iteration_dims = index.dims()[0..batch_dims].to_vec();
|
||||
|
||||
let mut batch_cartesian_coord = iteration_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let mut key = coord.clone();
|
||||
key[dim] = 0;
|
||||
let result = &results.get(&key).ok_or("missing result")?;
|
||||
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
|
||||
Ok::<ValType<F>, region::RegionError>(o)
|
||||
})?;
|
||||
if batch_cartesian_coord.is_empty() {
|
||||
batch_cartesian_coord.push(vec![]);
|
||||
}
|
||||
|
||||
let index_dim_multiplier: ValTensor<F> = dim_multiplier
|
||||
.get_slice(&[batch_dims..dims.len()])?
|
||||
.map(|x| ValType::Constant(x))
|
||||
.into();
|
||||
|
||||
let mut outer_results = vec![];
|
||||
|
||||
for coord in batch_cartesian_coord {
|
||||
let slice: Vec<Range<usize>> = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
|
||||
let mut index_slice = index.get_slice(&slice)?;
|
||||
index_slice.reshape(&index_dims[batch_dims..])?;
|
||||
|
||||
// expand the index to the full dims by iterating over the rest of the dims and inserting constants
|
||||
// eg in the case
|
||||
// batch_dims = 0
|
||||
// data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2]
|
||||
// indices = [[0,1],[1,0]] # indices_shape = [2, 2]
|
||||
// output = [[2,3],[4,5]] # output_shape = [2, 2]
|
||||
// the index should be expanded to the shape [2,2,3]: [[0,1,0],[0,1,1],[1,0,0],[1,0,1]]
|
||||
|
||||
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if inner_cartesian_coord.is_empty() {
|
||||
inner_cartesian_coord.push(vec![]);
|
||||
}
|
||||
|
||||
let indices = if last_dim < &input_rank {
|
||||
inner_cartesian_coord
|
||||
.iter()
|
||||
.map(|x| {
|
||||
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let index = index_slice.get_slice(&slice)?;
|
||||
|
||||
// map over cartesian coord of rest of dims and insert constants
|
||||
let grid = (*last_dim..input_rank)
|
||||
.map(|x| 0..dims[x])
|
||||
.multi_cartesian_product();
|
||||
|
||||
Ok(grid
|
||||
.map(|x| {
|
||||
let index = index.clone();
|
||||
let constant_valtensor: ValTensor<F> = Tensor::from(
|
||||
x.into_iter().map(|x| ValType::Constant(F::from(x as u64))),
|
||||
)
|
||||
.into();
|
||||
index.concat(constant_valtensor)
|
||||
})
|
||||
.collect::<Result<Vec<_>, TensorError>>()?)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
inner_cartesian_coord
|
||||
.iter()
|
||||
.map(|x| {
|
||||
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
Ok(index_slice.get_slice(&slice)?)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
|
||||
};
|
||||
|
||||
let mut const_offset = F::ZERO;
|
||||
for i in 0..batch_dims {
|
||||
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
|
||||
}
|
||||
|
||||
|
||||
let const_offset = create_constant_tensor(const_offset, 1);
|
||||
|
||||
let mut results = vec![];
|
||||
|
||||
for index_val in indices {
|
||||
let mut index_val = index_val.clone();
|
||||
index_val.flatten();
|
||||
let res = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[index_val.clone(), index_dim_multiplier.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
let res = res.concat(const_offset.clone())?;
|
||||
let res = sum(config, region, &[res])?;
|
||||
results.push(res.get_inner_tensor()?.clone());
|
||||
// assert that res is less than the product of the dims
|
||||
assert!(
|
||||
res.get_int_evals()?
|
||||
.iter()
|
||||
.all(|x| *x < dims.iter().product::<usize>() as i128),
|
||||
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
|
||||
dims.iter().product::<usize>(),
|
||||
index_val.show(),
|
||||
index_dim_multiplier.show(),
|
||||
res.show()
|
||||
);
|
||||
}
|
||||
|
||||
let result_tensor = Tensor::from(results.into_iter());
|
||||
|
||||
outer_results.push(result_tensor.combine()?);
|
||||
}
|
||||
|
||||
let output = Tensor::from(outer_results.into_iter());
|
||||
let output = output.combine()?;
|
||||
|
||||
Ok(output.into())
|
||||
}
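`linearize_nd_index` applies the same stride arithmetic to GatherND-style index tuples: short tuples are padded with every combination of the remaining coordinates, and, in the circuit, the batch coordinate is folded in as a constant offset. The sketch below covers only the batch_dims = 0 case and uses hypothetical helper names:

```python
from itertools import product

def row_major_strides(dims):
    s = [1] * len(dims)
    for i in reversed(range(len(dims) - 1)):
        s[i] = s[i + 1] * dims[i + 1]
    return s

def nd_to_flat(index_tuple, dims):
    # Pad a short index tuple (m < rank) with every combination of the
    # remaining coordinates, then dot with the row-major strides.
    s = row_major_strides(dims)
    free = [range(d) for d in dims[len(index_tuple):]]
    return [sum(c * st for c, st in zip(tuple(index_tuple) + rest, s))
            for rest in product(*free)]

dims = [2, 2, 2]
print(nd_to_flat((0, 1), dims))     # [2, 3] -> the slice data[0, 1, :]
print(nd_to_flat((1, 0, 1), dims))  # [5]    -> the scalar data[1, 0, 1]
```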
|
||||
|
||||
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
ordered: bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let (mut input, fullset) = (values[0].clone(), values[1].clone());
|
||||
let set_len = fullset.len();
|
||||
input.flatten();
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.get_int_evals()?;
|
||||
let mut fullset_evals = fullset.get_int_evals()?.into_iter().collect::<Vec<_>>();
|
||||
|
||||
// get the difference between the two vectors
|
||||
for eval in input_evals.iter() {
|
||||
// delete the first occurrence of that value
|
||||
if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
|
||||
fullset_evals.remove(pos);
|
||||
}
|
||||
}
|
||||
|
||||
// if input is a proper subset of fullset, the leftover has length set_len - input.len(); otherwise truncate it to that length. This is a patch for
// the fact that we can't have a tensor of unknowns when using constants during gen-settings
|
||||
if fullset_evals.len() != set_len - input.len() {
|
||||
fullset_evals.truncate(set_len - input.len());
|
||||
}
|
||||
|
||||
fullset_evals
|
||||
.iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
} else {
|
||||
let dim = fullset.len() - input.len();
|
||||
Tensor::new(Some(&vec![Value::<F>::unknown(); dim]), &[dim])?.into()
|
||||
};
|
||||
|
||||
// assign the claimed output
|
||||
claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
|
||||
// input and claimed output should be the shuffles of fullset
|
||||
// concatenate input and claimed output
|
||||
let input_and_claimed_output = input.concat(claimed_output.clone())?;
|
||||
|
||||
// assert that this is a permutation/shuffle
|
||||
shuffles(
|
||||
config,
|
||||
region,
|
||||
&[input_and_claimed_output.clone()],
|
||||
&[fullset.clone()],
|
||||
)?;
|
||||
|
||||
if ordered {
|
||||
// assert that the claimed output is sorted
|
||||
claimed_output = _sort_ascending(config, region, &[claimed_output])?;
|
||||
}
|
||||
|
||||
Ok(claimed_output)
|
||||
}
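`get_missing_set_elements` witnesses the multiset complement of `input` inside `fullset` and then constrains `input` concatenated with the claimed complement to be a shuffle of `fullset` (sorted when `ordered` is set). The witness computation itself is simple; a sketch with hypothetical values:

```python
def missing_set_elements(inp, fullset, ordered=True):
    # Remove one occurrence of each input element from fullset; what is left
    # is the claimed complement. In-circuit this claim is then checked with a
    # shuffle argument against fullset (plus a sort check if `ordered`).
    remaining = list(fullset)
    for v in inp:
        if v in remaining:
            remaining.remove(v)  # drops the first occurrence only
    return sorted(remaining) if ordered else remaining

print(missing_set_elements([3, 1, 1], [0, 1, 1, 2, 3, 4]))  # [0, 2, 4]
```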
|
||||
|
||||
/// Gather accumulated layout
|
||||
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -1105,88 +1343,158 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 3],
|
||||
dim: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone());
|
||||
let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());
|
||||
|
||||
assert_eq!(input.dims().len(), index.dims().len());
|
||||
|
||||
let mut assigned_len = vec![];
|
||||
|
||||
if !input.all_prev_assigned() {
|
||||
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
||||
assigned_len.push(input.len());
|
||||
}
|
||||
if !index.all_prev_assigned() {
|
||||
index = region.assign(&config.custom_gates.inputs[1], &index)?;
|
||||
assigned_len.push(index.len());
|
||||
}
|
||||
if !src.all_prev_assigned() {
|
||||
src = region.assign(&config.custom_gates.output, &src)?;
|
||||
assigned_len.push(src.len());
|
||||
region.increment(index.len());
|
||||
}
|
||||
|
||||
if !assigned_len.is_empty() {
|
||||
// safe to unwrap since we've just checked it has at least one element
|
||||
region.increment(*assigned_len.iter().max().unwrap());
|
||||
}
|
||||
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
|
||||
|
||||
// Calculate the output tensor size
|
||||
let input_dim = input.dims()[dim];
|
||||
let output_size = index.dims().to_vec();
|
||||
let claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_inner = input.get_int_evals()?;
|
||||
let index_inner = index.get_int_evals()?.map(|x| x as usize);
|
||||
let src_inner = src.get_int_evals()?;
|
||||
|
||||
// these will be assigned as constants
|
||||
let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x)));
|
||||
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
||||
region.increment(indices.len());
|
||||
let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;
|
||||
|
||||
// Allocate memory for the output tensor
|
||||
let cartesian_coord = output_size
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut output: Tensor<()> = Tensor::new(None, &output_size)?;
|
||||
|
||||
let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let index_val = index.get_inner_tensor()?.get(&coord);
|
||||
|
||||
let src_val = src.get_inner_tensor()?.get(&coord);
|
||||
let src_valtensor: ValTensor<F> = Tensor::from([src_val.clone()].into_iter()).into();
|
||||
|
||||
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
slice[dim] = 0..input_dim;
|
||||
|
||||
let mut sliced_input = input.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
|
||||
let index_valtensor: ValTensor<F> = Tensor::from([index_val.clone()].into_iter()).into();
|
||||
|
||||
let mask = equals(config, region, &[index_valtensor, indices.clone()])?;
|
||||
|
||||
let res = iff(config, region, &[mask, src_valtensor, sliced_input])?;
|
||||
|
||||
let input_cartesian_coord = slice.into_iter().multi_cartesian_product();
|
||||
let mutable_input_inner = input.get_inner_tensor_mut()?;
|
||||
|
||||
for (i, r) in res.get_inner_tensor()?.iter().enumerate() {
|
||||
let coord = input_cartesian_coord
|
||||
.clone()
|
||||
.nth(i)
|
||||
.ok_or("invalid coord")?;
|
||||
*mutable_input_inner.get_mut(&coord) = r.clone();
|
||||
}
|
||||
Ok(())
|
||||
res.iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
|
||||
output
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.map(|(i, _)| inner_loop_function(i, region))
|
||||
.collect::<Result<Vec<()>, Box<dyn Error>>>()?;
|
||||
// assign the claimed output
|
||||
let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
claimed_output.reshape(input.dims())?;
|
||||
|
||||
Ok(input)
|
||||
// scatter elements is the inverse of gather elements
|
||||
let (gather_src, linear_index) = gather_elements(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), index.clone()],
|
||||
dim,
|
||||
)?;
|
||||
|
||||
// assert this is equal to the src
|
||||
enforce_equality(config, region, &[gather_src, src])?;
|
||||
|
||||
let full_index_set: ValTensor<F> =
|
||||
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
|
||||
let input_indices = get_missing_set_elements(
|
||||
config,
|
||||
region,
|
||||
&[linear_index, full_index_set.clone()],
|
||||
true,
|
||||
)?;
|
||||
|
||||
claimed_output.flatten();
|
||||
let (gather_input, _) = gather_elements(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), input_indices.clone()],
|
||||
0,
|
||||
)?;
|
||||
// assert this is a subset of the input
|
||||
dynamic_lookup(
|
||||
config,
|
||||
region,
|
||||
&[input_indices, gather_input],
|
||||
&[full_index_set, input.clone()],
|
||||
)?;
|
||||
|
||||
claimed_output.reshape(input.dims())?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
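The scatter layout never computes the scatter in-circuit. It witnesses `claimed_output` and then checks it with gathers: gathering `claimed_output` at `index` must return `src`, and gathering it at the positions not written by `index` must return values taken from the original `input`. A plain-Python sketch of those two checks for the 1-D case (data values are made up):

```python
def check_scatter_1d(inp, index, src, claimed):
    # 1) positions named by `index` must now hold `src`
    assert all(claimed[i] == s for i, s in zip(index, src))
    # 2) positions not named by `index` must be untouched copies of `inp`
    untouched = [i for i in range(len(inp)) if i not in set(index)]
    assert all(claimed[i] == inp[i] for i in untouched)
    return True

inp, index, src = [10, 11, 12, 13], [1, 3], [70, 90]
claimed = [10, 70, 12, 90]
print(check_scatter_1d(inp, index, src, claimed))  # True
```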
|
||||
|
||||
/// Scatter Nd
|
||||
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 3],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());
|
||||
|
||||
if !index.all_prev_assigned() {
|
||||
index = region.assign(&config.custom_gates.inputs[1], &index)?;
|
||||
region.increment(index.len());
|
||||
}
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
|
||||
|
||||
let claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_inner = input.get_int_evals()?;
|
||||
let index_inner = index.get_int_evals()?.map(|x| x as usize);
|
||||
let src_inner = src.get_int_evals()?;
|
||||
|
||||
let res = tensor::ops::scatter_nd(&input_inner, &index_inner, &src_inner)?;
|
||||
|
||||
res.iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
|
||||
// assign the claimed output
|
||||
let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
claimed_output.reshape(input.dims())?;
|
||||
|
||||
|
||||
// scatter nd is the inverse of gather nd
|
||||
let (gather_src, linear_index) =
|
||||
gather_nd(config, region, &[claimed_output.clone(), index.clone()], 0)?;
|
||||
|
||||
// assert this is equal to the src
|
||||
enforce_equality(config, region, &[gather_src, src])?;
|
||||
|
||||
let full_index_set: ValTensor<F> =
|
||||
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
|
||||
|
||||
let input_indices = get_missing_set_elements(
|
||||
config,
|
||||
region,
|
||||
&[linear_index, full_index_set.clone()],
|
||||
true,
|
||||
)?;
|
||||
|
||||
// now that it is flattened we can gather over elements on dim 0
|
||||
claimed_output.flatten();
|
||||
let (gather_input, _) = gather_elements(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), input_indices.clone()],
|
||||
0,
|
||||
)?;
|
||||
|
||||
// assert this is a subset of the input
|
||||
dynamic_lookup(
|
||||
config,
|
||||
region,
|
||||
&[input_indices, gather_input],
|
||||
&[full_index_set, input.clone()],
|
||||
)?;
|
||||
|
||||
claimed_output.reshape(input.dims())?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
/// sum accumulated layout
|
||||
@@ -1443,17 +1751,11 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
dim: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// these will be assigned as constants
|
||||
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
|
||||
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
||||
region.increment(indices.len());
|
||||
|
||||
let argmax = move |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1]|
|
||||
-> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
argmax(config, region, values, indices.clone())
|
||||
};
|
||||
let argmax =
|
||||
move |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1]|
|
||||
-> Result<ValTensor<F>, Box<dyn Error>> { argmax(config, region, values) };
|
||||
|
||||
// calculate value of output
|
||||
axes_wise_op(config, region, values, &[dim], argmax)
|
||||
@@ -1479,18 +1781,12 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
dim: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// calculate value of output
|
||||
// these will be assigned as constants
|
||||
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
|
||||
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
||||
region.increment(indices.len());
|
||||
|
||||
let argmin = move |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1]|
|
||||
-> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
argmin(config, region, values, indices.clone())
|
||||
};
|
||||
let argmin =
|
||||
move |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1]|
|
||||
-> Result<ValTensor<F>, Box<dyn Error>> { argmin(config, region, values) };
|
||||
|
||||
axes_wise_op(config, region, values, &[dim], argmin)
|
||||
}
|
||||
@@ -2665,7 +2961,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
indices: ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// this is safe because we later constrain it
|
||||
let argmax = values[0]
|
||||
@@ -2688,7 +2983,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone(), assigned_argmax.clone()],
|
||||
indices,
|
||||
)?;
|
||||
|
||||
let max_val = max(config, region, &[values[0].clone()])?;
|
||||
@@ -2703,7 +2997,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
indices: ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// this is safe because we later constrain it
|
||||
let argmin = values[0]
|
||||
@@ -2727,7 +3020,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone(), assigned_argmin.clone()],
|
||||
indices,
|
||||
)?;
|
||||
let min_val = min(config, region, &[values[0].clone()])?;
|
||||
|
||||
|
||||
@@ -14,10 +14,17 @@ pub enum PolyOp {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
GatherND {
|
||||
batch_dims: usize,
|
||||
indices: Option<Tensor<usize>>,
|
||||
},
|
||||
ScatterElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
ScatterND {
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
MultiBroadcastTo {
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
@@ -89,7 +96,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
|
||||
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
|
||||
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
|
||||
PolyOp::ScatterND { .. } => "SCATTERND".into(),
|
||||
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
|
||||
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
|
||||
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
|
||||
@@ -213,6 +222,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
};
|
||||
tensor::ops::gather_elements(&x, &y, *dim)
|
||||
}
|
||||
PolyOp::GatherND {
|
||||
indices,
|
||||
batch_dims,
|
||||
} => {
|
||||
let x = inputs[0].clone();
|
||||
let y = if let Some(idx) = indices {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
tensor::ops::gather_nd(&x, &y, *batch_dims)
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
|
||||
@@ -229,6 +250,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
};
|
||||
tensor::ops::scatter(&x, &idx, &src, *dim)
|
||||
}
|
||||
|
||||
PolyOp::ScatterND { constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
let idx = if let Some(idx) = constant_idx {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
let src = if constant_idx.is_some() {
|
||||
inputs[1].clone()
|
||||
} else {
|
||||
inputs[2].clone()
|
||||
};
|
||||
tensor::ops::scatter_nd(&x, &idx, &src)
|
||||
}
|
||||
}?;
|
||||
|
||||
Ok(ForwardResult { output: res })
|
||||
@@ -276,7 +312,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
|
||||
}
|
||||
}
|
||||
PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices,
|
||||
} => {
|
||||
if let Some(idx) = indices {
|
||||
tensor::ops::gather_nd(values[0].get_inner_tensor()?, idx, *batch_dims)?.into()
|
||||
} else {
|
||||
layouts::gather_nd(config, region, values[..].try_into()?, *batch_dims)?.0
|
||||
}
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
@@ -292,6 +338,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
PolyOp::ScatterND { constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::scatter_nd(
|
||||
values[0].get_inner_tensor()?,
|
||||
idx,
|
||||
values[1].get_inner_tensor()?,
|
||||
)?
|
||||
.into()
|
||||
} else {
|
||||
layouts::scatter_nd(config, region, values[..].try_into()?)?
|
||||
}
|
||||
}
|
||||
PolyOp::DeConv {
|
||||
padding,
|
||||
output_padding,
|
||||
@@ -389,7 +447,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
vec![1, 2]
|
||||
} else if matches!(self, PolyOp::Concat { .. }) {
|
||||
(0..100).collect()
|
||||
} else if matches!(self, PolyOp::ScatterElements { .. }) {
|
||||
} else if matches!(self, PolyOp::ScatterElements { .. })
|
||||
| matches!(self, PolyOp::ScatterND { .. })
|
||||
{
|
||||
vec![0, 2]
|
||||
} else {
|
||||
vec![]
|
||||
|
||||
@@ -402,9 +402,6 @@ pub enum Commands {
|
||||
/// Number of logrows to use for srs. Overrides settings_path if specified.
|
||||
#[arg(long, default_value = None)]
|
||||
logrows: Option<u32>,
|
||||
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
|
||||
#[arg(long, default_value = DEFAULT_CHECKMODE)]
|
||||
check: CheckMode,
|
||||
},
|
||||
/// Loads model and input and runs mock prover (for testing)
|
||||
Mock {
|
||||
|
||||
@@ -159,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
srs_path,
|
||||
settings_path,
|
||||
logrows,
|
||||
check,
|
||||
} => get_srs_cmd(srs_path, settings_path, logrows, check).await,
|
||||
} => get_srs_cmd(srs_path, settings_path, logrows).await,
|
||||
Commands::Table { model, args } => table(model, args),
|
||||
#[cfg(feature = "render")]
|
||||
Commands::RenderCircuit {
|
||||
@@ -492,23 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
|
||||
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
|
||||
use std::io::Read;
|
||||
|
||||
let path = get_srs_path(logrows, srs_path);
|
||||
let file = std::fs::File::open(path.clone())?;
|
||||
let file = std::fs::File::open(path)?;
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
let mut buffer = vec![];
|
||||
let mut reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
let bytes_read = reader.read_to_end(&mut buffer)?;
|
||||
|
||||
info!(
|
||||
"read {} bytes from SRS file (vector of len = {})",
|
||||
"read {} bytes from file (vector of len = {})",
|
||||
bytes_read,
|
||||
buffer.len()
|
||||
);
|
||||
|
||||
let hash = sha256::digest(buffer);
|
||||
info!("SRS hash: {}", hash);
|
||||
info!("file hash: {}", hash);
|
||||
|
||||
Ok(hash)
|
||||
}
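`get_file_hash` simply streams the file and takes its SHA-256, which `check_srs_hash` then compares against the table of published hashes. The same check can be reproduced outside ezkl; the path below is a hypothetical example, not a guaranteed location:

```python
import hashlib

def file_sha256(path, chunk_size=1 << 20):
    # Stream the file in chunks so large SRS files do not need to fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# print(file_sha256("/path/to/kzg17.srs"))  # compare against the published hash
```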
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
|
||||
let path = get_srs_path(logrows, srs_path);
|
||||
let hash = get_file_hash(&path)?;
|
||||
|
||||
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
|
||||
Some(h) => h,
|
||||
@@ -532,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: Option<PathBuf>,
|
||||
logrows: Option<u32>,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
// logrows overrides settings
|
||||
|
||||
@@ -560,21 +563,20 @@ pub(crate) async fn get_srs_cmd(
|
||||
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
|
||||
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
|
||||
// check the SRS
|
||||
if matches!(check_mode, CheckMode::SAFE) {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let pb = init_spinner();
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("SRS validated");
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let pb = init_spinner();
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("SRS validated.");
|
||||
|
||||
info!("Saving SRS to disk...");
|
||||
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
|
||||
|
||||
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
|
||||
buffer.write_all(reader.get_ref())?;
|
||||
buffer.flush()?;
|
||||
params.write(&mut buffer)?;
|
||||
|
||||
info!("Saved SRS to disk.");
|
||||
|
||||
info!("SRS downloaded");
|
||||
} else {
|
||||
@@ -969,8 +971,8 @@ pub(crate) fn calibrate(
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// drop the gag
|
||||
|
||||
// drop the gag
|
||||
#[cfg(unix)]
|
||||
drop(_r);
|
||||
#[cfg(unix)]
|
||||
@@ -1695,7 +1697,7 @@ pub(crate) fn fuzz(
|
||||
let logrows = circuit.settings().run_args.logrows;
|
||||
|
||||
info!("setting up tests");
|
||||
|
||||
#[cfg(unix)]
|
||||
let _r = Gag::stdout()?;
|
||||
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
|
||||
|
||||
@@ -1713,6 +1715,7 @@ pub(crate) fn fuzz(
|
||||
let public_inputs = circuit.prepare_public_inputs(&data)?;
|
||||
|
||||
let strategy = KZGSingleStrategy::new(¶ms);
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_r);
|
||||
|
||||
info!("starting fuzzing");
|
||||
@@ -1903,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
|
||||
passed: &AtomicBool,
|
||||
) {
|
||||
let num_failures = AtomicI64::new(0);
|
||||
#[cfg(unix)]
|
||||
let _r = Gag::stdout().unwrap();
|
||||
|
||||
let pb = init_bar(num_runs as u64);
|
||||
@@ -1916,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
|
||||
pb.inc(1);
|
||||
});
|
||||
pb.finish_with_message("Done.");
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_r);
|
||||
info!(
|
||||
"num failures: {} out of {}",
|
||||
|
||||
@@ -23,7 +23,10 @@ use std::sync::Arc;
|
||||
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use tract_onnx::tract_core::ops::{
|
||||
array::{Gather, GatherElements, MultiBroadcastTo, OneHot, ScatterElements, Slice, Topk},
|
||||
array::{
|
||||
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
|
||||
Slice, Topk,
|
||||
},
|
||||
change_axes::AxisOp,
|
||||
cnn::{Conv, Deconv},
|
||||
einsum::EinSum,
|
||||
@@ -467,6 +470,78 @@ pub fn new_op_from_onnx(
|
||||
|
||||
// Extract the max value
|
||||
}
|
||||
"ScatterNd" => {
|
||||
if inputs.len() != 3 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"scatter nd".to_string(),
|
||||
)));
|
||||
};
|
||||
// just verify it deserializes correctly
|
||||
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
|
||||
|
||||
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
|
||||
constant_idx: None,
|
||||
});
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
}
|
||||
// }
|
||||
|
||||
if inputs[1].opkind().is_input() {
|
||||
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale: 0,
|
||||
datum_type: InputType::TDim,
|
||||
}));
|
||||
inputs[1].bump_scale(0);
|
||||
}
|
||||
|
||||
op
|
||||
}
|
||||
|
||||
"GatherNd" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"gather nd".to_string(),
|
||||
)));
|
||||
};
|
||||
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
|
||||
let batch_dims = op.batch_dims;
|
||||
|
||||
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices: None,
|
||||
});
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
}
|
||||
// }
|
||||
|
||||
if inputs[1].opkind().is_input() {
|
||||
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale: 0,
|
||||
datum_type: InputType::TDim,
|
||||
}));
|
||||
inputs[1].bump_scale(0);
|
||||
}
|
||||
|
||||
op
|
||||
}
|
||||
|
||||
"GatherElements" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
|
||||
@@ -484,7 +484,6 @@ fn get_srs(
|
||||
srs_path,
|
||||
settings_path,
|
||||
logrows,
|
||||
CheckMode::SAFE,
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to get srs: {}", e);
|
||||
@@ -1104,6 +1103,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<PyG1Affine>()?;
|
||||
m.add_class::<PyG1>()?;
|
||||
m.add_class::<PyTestDataSource>()?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;
|
||||
|
||||
@@ -673,6 +673,68 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
Tensor::new(Some(&res), &dims)
|
||||
}
|
||||
|
||||
/// Set a slice of the Tensor.
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
|
||||
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
|
||||
/// assert_eq!(a, b);
|
||||
/// ```
|
||||
pub fn set_slice(
|
||||
&mut self,
|
||||
indices: &[Range<usize>],
|
||||
value: &Tensor<T>,
|
||||
) -> Result<(), TensorError>
|
||||
where
|
||||
T: Send + Sync,
|
||||
{
|
||||
if indices.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
if self.dims.len() < indices.len() {
|
||||
return Err(TensorError::DimError(format!(
|
||||
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
|
||||
indices, self.dims
|
||||
)));
|
||||
}
|
||||
|
||||
// if indices weren't specified we fill them in as required
|
||||
let mut full_indices = indices.to_vec();
|
||||
|
||||
let omitted_dims = (indices.len()..self.dims.len())
|
||||
.map(|i| self.dims[i])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for dim in &omitted_dims {
|
||||
full_indices.push(0..*dim);
|
||||
}
|
||||
|
||||
let full_dims = full_indices
|
||||
.iter()
|
||||
.map(|x| x.end - x.start)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// now broadcast the value to the full dims
|
||||
let value = value.expand(&full_dims)?;
|
||||
|
||||
let cartesian_coord: Vec<Vec<usize>> = full_indices
|
||||
.iter()
|
||||
.cloned()
|
||||
.multi_cartesian_product()
|
||||
.collect();
|
||||
|
||||
let _ = cartesian_coord
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, e)| {
|
||||
self.set(e, value[i].clone());
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(())
|
||||
}
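`set_slice` fills in any omitted trailing ranges, broadcasts the value to the resulting shape, and writes element-wise. The numpy equivalent of the doc-test above, for comparison (illustrative only):

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6]])
# set_slice(&[1..2], value) with value of shape [1, 3]: the omitted second
# range defaults to 0..3 and the value broadcasts over the selected block.
a[1:2, :] = np.array([[1, 2, 3]])
print(a)  # [[1 2 3] [1 2 3]]
```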
|
||||
|
||||
/// Get the array index from rows / columns indices.
|
||||
///
|
||||
/// ```
|
||||
|
||||
@@ -2,7 +2,7 @@ use super::TensorError;
|
||||
use crate::tensor::{Tensor, TensorType};
|
||||
use itertools::Itertools;
|
||||
use maybe_rayon::{
|
||||
iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator,
|
||||
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
|
||||
prelude::IntoParallelRefIterator,
|
||||
};
|
||||
use std::collections::{HashMap, HashSet};
|
||||

@@ -1328,6 +1328,316 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
    Ok(output)
}

/// Gather ND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to gather
/// * `batch_dims` - Number of batch dimensions
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::gather_nd;
/// let x = Tensor::<i128>::new(
///     Some(&[0, 1, 2, 3]),
///     &[2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
///     Some(&[0, 0, 1, 1]),
///     &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 3]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///     Some(&[1, 0]),
///     &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
///     Some(&[0, 1, 2, 3, 4, 5, 6, 7]),
///     &[2, 2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
///     Some(&[0, 1, 1, 0]),
///     &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///     Some(&[0, 1, 1, 0]),
///     &[2, 1, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///     Some(&[1, 0]),
///     &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///     Some(&[0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]),
///     &[2, 2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///     Some(&[0, 1, 0, 0, 1, 1, 1, 0]),
///     &[2, 2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///     Some(&[0, 1, 0, 1, 1, 1]),
///     &[2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// ```
pub fn gather_nd<T: TensorType + Send + Sync>(
    input: &Tensor<T>,
    index: &Tensor<usize>,
    batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
    // Calculate the output tensor size
    let index_dims = index.dims().to_vec();
    let input_dims = input.dims().to_vec();
    let last_value = index_dims
        .last()
        .ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
    if last_value > &(input_dims.len() - batch_dims) {
        return Err(TensorError::DimMismatch("gather_nd".to_string()));
    }

    let output_size =
    // If indices_shape[-1] == r-b, since the rank of indices is q,
    // indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
    // where N is an integer equal to the product of 1 and all the elements in the batch dimensions of indices_shape.
    // Let us think of each such r-b ranked tensor as indices_slice.
    // Each scalar value corresponding to data[0:b-1, indices_slice] is filled into
    // the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor.
    // If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension < r-b.
    // Let us think of each such tensor as indices_slice.
    // Each tensor slice corresponding to data[0:b-1, indices_slice, :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor.
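    //
    // Worked example (added note, mirroring the third doc example above): input dims [2, 2, 2]
    // (r = 3), index dims [2, 2] (q = 2, indices_shape[-1] = 2), batch_dims b = 0:
    //   output rank = 3 + 2 - 1 - 0 - 2 = 2
    //   output dims = index_dims[..1] ++ input_dims[(b + 2)..] = [2] ++ [2] = [2, 2]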
    {
        let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;

        let mut dims = index_dims[..index_dims.len() - 1].to_vec();
        let input_offset = batch_dims + last_value;
        dims.extend(input_dims[input_offset..input_dims.len()].to_vec());

        assert_eq!(output_rank, dims.len());
        dims
    };

    // cartesian coord over batch dims
    let mut batch_cartesian_coord = input_dims[0..batch_dims]
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    if batch_cartesian_coord.is_empty() {
        batch_cartesian_coord.push(vec![]);
    }

    let outputs = batch_cartesian_coord
        .par_iter()
        .map(|batch_coord| {
            let batch_slice = batch_coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
            let mut index_slice = index.get_slice(&batch_slice)?;
            index_slice.reshape(&index.dims()[batch_dims..])?;
            let mut input_slice = input.get_slice(&batch_slice)?;
            input_slice.reshape(&input.dims()[batch_dims..])?;

            let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
                .iter()
                .map(|x| 0..*x)
                .multi_cartesian_product()
                .collect::<Vec<_>>();

            if inner_cartesian_coord.is_empty() {
                inner_cartesian_coord.push(vec![]);
            }

            let output = inner_cartesian_coord
                .iter()
                .map(|coord| {
                    let slice = coord
                        .iter()
                        .map(|x| *x..*x + 1)
                        .chain(batch_coord.iter().map(|x| *x..*x + 1))
                        .collect::<Vec<_>>();

                    let index_slice = index_slice
                        .get_slice(&slice)
                        .unwrap()
                        .iter()
                        .map(|x| *x..*x + 1)
                        .collect::<Vec<_>>();

                    input_slice.get_slice(&index_slice).unwrap()
                })
                .collect::<Tensor<_>>();

            output.combine()
        })
        .collect::<Result<Vec<_>, _>>()?;

    let mut outputs = outputs.into_iter().flatten().collect::<Tensor<_>>();

    outputs.reshape(&output_size)?;

    Ok(outputs)
}

/// Scatter ND.
/// This operator is the inverse of GatherND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to scatter
/// * `src` - Tensor of src values
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scatter_nd;
/// let x = Tensor::<i128>::new(
///     Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
///     &[8],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
///     Some(&[4, 3, 1, 7]),
///     &[4, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
///     Some(&[9, 10, 11, 12]),
///     &[4],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
///     Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
///            1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
///            8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
///            8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
///     &[4, 4, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
///     Some(&[0, 2]),
///     &[2, 1],
/// ).unwrap();
///
/// let src = Tensor::<i128>::new(
///     Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
///            1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
///     ]),
///     &[2, 4, 4],
/// ).unwrap();
///
/// let result = scatter_nd(&x, &index, &src).unwrap();
///
/// let expected = Tensor::<i128>::new(
///     Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
///            1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
///            1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
///            8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
///     &[4, 4, 4],
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
///     Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
///     &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
///     Some(&[0, 1]),
///     &[2, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
///     Some(&[9, 10]),
///     &[2],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
///     Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
///     &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
///     Some(&[0, 1]),
///     &[1, 1, 2],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
///     Some(&[9]),
///     &[1, 1],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ```
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
    input: &Tensor<T>,
    index: &Tensor<usize>,
    src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
    // Calculate the output tensor size
    let index_dims = index.dims().to_vec();
    let input_dims = input.dims().to_vec();
    let last_value = index_dims
        .last()
        .ok_or(TensorError::DimMismatch("scatter_nd".to_string()))?;
    if last_value > &input_dims.len() {
        return Err(TensorError::DimMismatch("scatter_nd".to_string()));
    }

    let mut output = input.clone();

    let cartesian_coord = index_dims[0..index_dims.len() - 1]
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    cartesian_coord
        .iter()
        .map(|coord| {
            let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
            let index_val = index.get_slice(&slice)?;
            let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
            let src_val = src.get_slice(&slice)?;
            output.set_slice(&index_slice, &src_val)?;
            Ok(())
        })
        .collect::<Result<Vec<_>, _>>()?;

    Ok(output)
}

fn axes_op<T: TensorType + Send + Sync>(
    a: &Tensor<T>,
    axes: &[usize],

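The doc comments above call `scatter_nd` the inverse of `gather_nd`. As a quick illustration of that claim, here is an added sketch (not part of the diff; it only reuses the values from the first `scatter_nd` doc example and the two public functions introduced above): scattering `src` into `x` and then gathering with the same indices returns `src`.

```rust
use ezkl::tensor::Tensor;
use ezkl::tensor::ops::{gather_nd, scatter_nd};

fn main() {
    let x = Tensor::<i128>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8]), &[8]).unwrap();
    let index = Tensor::<usize>::new(Some(&[4, 3, 1, 7]), &[4, 1]).unwrap();
    let src = Tensor::<i128>::new(Some(&[9, 10, 11, 12]), &[4]).unwrap();

    // Write the updates into x at positions 4, 3, 1 and 7 ...
    let scattered = scatter_nd(&x, &index, &src).unwrap();
    // ... then read them back out with the same indices: this recovers `src`.
    let roundtrip = gather_nd(&scattered, &index, 0).unwrap();
    assert_eq!(roundtrip, src);
}
```
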
@@ -494,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
            }
            _ => return Err(Box::new(TensorError::WrongMethod)),
        };
-        Ok(integer_evals.into_iter().into())
+        let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
+        match tensor.reshape(self.dims()) {
+            _ => {}
+        };
+
+        Ok(tensor)
    }

    /// Calls `pad_to_zero_rem` on the inner tensor.
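
The hunk above suggests the integer-evals getter used to return whatever shape the iterator conversion produced, and now restores the `ValTensor`'s own dims before returning. Below is a minimal added sketch of that collect-then-reshape pattern, using only the `Tensor` API visible elsewhere in this diff; the dims `[2, 3]` are an assumed stand-in and are not taken from the commit.

```rust
use ezkl::tensor::Tensor;

fn main() {
    // Stand-in for integer evaluations collected from a 2x3 tensor.
    let integer_evals: Vec<i128> = vec![1, 2, 3, 4, 5, 6];

    // Collecting from an iterator does not know the original dims ...
    let mut tensor: Tensor<i128> = integer_evals.into_iter().into();

    // ... so they are restored explicitly, as the new code does with `self.dims()`.
    tensor.reshape(&[2, 3]).unwrap();
    assert_eq!(tensor.dims(), &[2, 3]);
}
```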

@@ -193,7 +193,7 @@ mod native_tests {
        "1l_tiny_div",
    ];

-    const TESTS: [&str; 77] = [
+    const TESTS: [&str; 79] = [
        "1l_mlp", //0
        "1l_slice",
        "1l_concat",
@@ -275,6 +275,8 @@ mod native_tests {
        "ltsf",
        "remainder", //75
        "bitshift",
+        "gather_nd",
+        "scatter_nd",
    ];

    const WASM_TESTS: [&str; 46] = [
@@ -502,7 +504,7 @@ mod native_tests {
        }
    });

-    seq!(N in 0..=76 {
+    seq!(N in 0..=78 {

    #(#[test_case(TESTS[N])])*
    #[ignore]
@@ -589,13 +591,16 @@ mod native_tests {

    #(#[test_case(TESTS[N])])*
    fn mock_large_batch_public_outputs_(test: &str) {
-        crate::native_tests::init_binary();
-        let test_dir = TempDir::new(test).unwrap();
-        let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-        let large_batch_dir = &format!("large_batches_{}", test);
-        crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
-        mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
-        test_dir.close().unwrap();
+        // currently variable output rank is not supported in ONNX
+        if test != "gather_nd" {
+            crate::native_tests::init_binary();
+            let test_dir = TempDir::new(test).unwrap();
+            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
+            let large_batch_dir = &format!("large_batches_{}", test);
+            crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
+            mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
+            test_dir.close().unwrap();
+        }
    }

    #(#[test_case(TESTS[N])])*