Compare commits

...

4 Commits

Author SHA1 Message Date
jmjac
de9e3f2673 Add __version__ to python bindings (#739) 2024-03-13 14:22:20 +00:00
dante
a1450f8df7 feat: gather_nd/scatter_nd support (#737) 2024-03-11 22:05:40 +00:00
dante
ea535e2ecd refactor: use linear index constraints for gather and scatter (#735) 2024-03-09 18:00:21 +00:00
Alexander Camuto
f8aa91ed08 fix: windows compile 2024-03-06 11:40:44 +00:00
20 changed files with 1196 additions and 259 deletions

View File

@@ -0,0 +1,48 @@
from torch import nn
import json
import numpy as np
import tf2onnx
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
# gather_nd in tf then export to onnx
x = in1 = Input((15, 18,))
w = in2 = Input((15, 1), dtype=tf.int32)
x = tf.gather_nd(x, w, batch_dims=1)
tm = Model((in1, in2), x )
tm.summary()
tm.compile(optimizer='adam', loss='mse')
shape = [1, 15, 18]
index_shape = [1, 15, 1]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = 0.1*np.random.rand(1,*shape)
# w = random int tensor
w = np.random.randint(0, 10, index_shape)
spec = tf.TensorSpec(shape, tf.float32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
data = dict(
input_data=[d, d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import sys
import json
sys.path.append("..")
class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)
def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]
class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3
configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)
x = torch.randn(1, seq_len, pred_len)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}

Binary file not shown.

View File

@@ -646,14 +646,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
dim_indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, index) = (values[0].clone(), values[1].clone());
input.flatten();
if !(dim_indices.all_prev_assigned() || region.is_dummy()) {
return Err("dim_indices must be assigned".into());
}
// these will be assigned as constants
let dim_indices: ValTensor<F> =
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;
@@ -929,91 +928,25 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 2],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index_clone) = (values[0].clone(), values[1].clone());
let (input, mut index_clone) = (values[0].clone(), values[1].clone());
index_clone.flatten();
if index_clone.is_singleton() {
index_clone.reshape(&[1])?;
}
let mut assigned_len = vec![];
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
assigned_len.push(input.len());
}
if !index_clone.all_prev_assigned() {
index_clone = region.assign(&config.custom_gates.inputs[1], &index_clone)?;
assigned_len.push(index_clone.len());
}
if !assigned_len.is_empty() {
// safe to unwrap since we've just checked it has at least one element
region.increment(*assigned_len.iter().max().unwrap());
}
// Calculate the output tensor size
let input_dims = input.dims();
let mut output_size = input_dims.to_vec();
output_size[dim] = index_clone.dims()[0];
// these will be assigned as constants
let mut indices = Tensor::from((0..input.dims()[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let linear_index =
linearize_element_index(config, region, &[index_clone], input_dims, dim, true)?;
let mut iteration_dims = output_size.clone();
iteration_dims[dim] = 1;
let mut output = select(config, region, &[input, linear_index])?;
// Allocate memory for the output tensor
let cartesian_coord = iteration_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut results = HashMap::new();
for coord in cartesian_coord {
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dims[dim];
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let res = select(
config,
region,
&[sliced_input, index_clone.clone()],
indices.clone(),
)?;
results.insert(coord, res);
}
// Allocate memory for the output tensor
let cartesian_coord = output_size
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
let coord = cartesian_coord[i].clone();
let mut key = coord.clone();
key[dim] = 0;
let result = &results.get(&key).ok_or("missing result")?;
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
Ok::<ValType<F>, region::RegionError>(o)
})?;
// Reshape the output tensor
if index_clone.is_singleton() {
output_size.remove(dim);
}
output.reshape(&output_size)?;
Ok(output.into())
Ok(output)
}
/// Gather accumulated layout
@@ -1022,82 +955,387 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index) = (values[0].clone(), values[1].clone());
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
let (input, index) = (values[0].clone(), values[1].clone());
assert_eq!(input.dims().len(), index.dims().len());
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
}
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
}
region.increment(std::cmp::max(input.len(), index.len()));
// Calculate the output tensor size
let input_dims = input.dims();
let output_size = index.dims().to_vec();
// these will be assigned as constants
let mut indices = Tensor::from((0..input_dims[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let linear_index = linearize_element_index(config, region, &[index], input.dims(), dim, false)?;
let mut iteration_dims = output_size.clone();
iteration_dims[dim] = 1;
let mut output = select(config, region, &[input, linear_index.clone()])?;
output.reshape(&output_size)?;
Ok((output, linear_index))
}
/// Gather accumulated layout
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
batch_dims: usize,
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
let (input, index) = (values[0].clone(), values[1].clone());
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
if index_dims.last() > Some(&(input_dims.len() - batch_dims)) {
return Err(TensorError::DimMismatch("gather_nd".to_string()).into());
}
let output_size =
// If indices_shape[-1] == r-b, since the rank of indices is q,
// indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
// where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
// Let us think of each such r-b ranked tensor as indices_slice.
// Each scalar value corresponding to data[0:b-1,indices_slice] is filled into
// the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
// if indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b.
// Let us think of each such tensors as indices_slice.
// Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
{
let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
let mut dims = index_dims[..index_dims.len() - 1].to_vec();
let input_offset = batch_dims + last_value;
dims.extend(input_dims[input_offset..input_dims.len()].to_vec());
assert_eq!(output_rank, dims.len());
dims
};
let linear_index = linearize_nd_index(config, region, &[index], input.dims(), batch_dims)?;
let mut output = select(config, region, &[input, linear_index.clone()])?;
output.reshape(&output_size)?;
Ok((output, linear_index))
}
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// FOr instance if the dims is [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
dims: &[usize],
dim: usize,
is_flat_index: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let index = values[0].clone();
if !is_flat_index {
assert_eq!(index.dims().len(), dims.len());
// if the index is already flat, return it
if index.dims().len() == 1 {
return Ok(index);
}
}
let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;
let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
let mut res = 1;
for dim in dims.iter().skip(i + 1) {
res *= dim;
}
Ok::<_, region::RegionError>(F::from(res as u64))
})?;
let iteration_dims = if is_flat_index {
let mut dims = dims.to_vec();
dims[dim] = index.len();
dims
} else {
index.dims().to_vec()
};
// Allocate memory for the output tensor
let cartesian_coord = iteration_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut results = HashMap::new();
let val_dim_multiplier: ValTensor<F> = dim_multiplier
.get_slice(&[dim..dim + 1])?
.map(|x| ValType::Constant(x))
.into();
for coord in cartesian_coord {
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dims[dim];
let mut output = Tensor::new(None, &[cartesian_coord.len()])?;
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
let coord = cartesian_coord[i].clone();
let slice: Vec<Range<usize>> = if is_flat_index {
coord[dim..dim + 1].iter().map(|x| *x..*x + 1).collect()
} else {
coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>()
};
slice[dim] = 0..output_size[dim];
let mut sliced_index = index.get_slice(&slice)?;
sliced_index.flatten();
let index_val = index.get_slice(&slice)?;
let res = select(
let mut const_offset = F::ZERO;
for i in 0..dims.len() {
if i != dim {
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
}
}
let const_offset = create_constant_tensor(const_offset, 1);
let res = pairwise(
config,
region,
&[sliced_input, sliced_index],
indices.clone(),
&[index_val, val_dim_multiplier.clone()],
BaseOp::Mult,
)?;
results.insert(coord, res);
}
let res = pairwise(config, region, &[res, const_offset], BaseOp::Add)?;
// Allocate memory for the output tensor
let cartesian_coord = output_size
Ok(res.get_inner_tensor()?[0].clone())
};
region.apply_in_loop(&mut output, inner_loop_function)?;
Ok(output.into())
}
/// Takes a tensor representing a nd index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and batch_dims integer b, this operator gathers slices of data into an output tensor of rank q + r - indices_shape[-1] - 1 - b.
/// indices is an q-dimensional integer tensor, best thought of as a (q-1)-dimensional tensor of index-tuples into data, where each element defines a slice of data
/// batch_dims (denoted as b) is an integer indicating the number of batch dimensions, i.e the leading b number of dimensions of data tensor and indices are representing the batches, and the gather starts from the b+1 dimension.
/// Some salient points about the inputs rank and shape:
/// r >= 1 and q >= 1 are to be honored. There is no dependency condition to be met between ranks r and q
/// The first b dimensions of the shape of indices tensor and data tensor must be equal.
/// b < min(q, r) is to be honored.
/// The indices_shape[-1] should have a value between 1 (inclusive) and rank r-b (inclusive)
/// All values in indices are expected to be within bounds [-s, s-1] along axis of size s (i.e.) -data_shape[i] <= indices[...,i] <= data_shape[i] - 1. It is an error if any of the index values are out of bounds.
// The output is computed as follows:
/// The output tensor is obtained by mapping each index-tuple in the indices tensor to the corresponding slice of the input data.
/// If indices_shape[-1] > r-b => error condition
/// If indices_shape[-1] == r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b, where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
/// Let us think of each such r-b ranked tensor as indices_slice. Each scalar value corresponding to data[0:b-1,indices_slice] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Example 1 below)
/// If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b. Let us think of each such tensors as indices_slice. Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Examples 2, 3, 4 and 5 below)
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
dims: &[usize],
batch_dims: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let index = values[0].clone();
let index_dims = index.dims().to_vec();
let last_dim = index.dims().last().unwrap();
let input_rank = dims[batch_dims..].len();
let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;
let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
let mut res = 1;
for dim in dims.iter().skip(i + 1) {
res *= dim;
}
Ok::<_, region::RegionError>(F::from(res as u64))
})?;
let iteration_dims = index.dims()[0..batch_dims].to_vec();
let mut batch_cartesian_coord = iteration_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
let coord = cartesian_coord[i].clone();
let mut key = coord.clone();
key[dim] = 0;
let result = &results.get(&key).ok_or("missing result")?;
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
Ok::<ValType<F>, region::RegionError>(o)
})?;
if batch_cartesian_coord.is_empty() {
batch_cartesian_coord.push(vec![]);
}
let index_dim_multiplier: ValTensor<F> = dim_multiplier
.get_slice(&[batch_dims..dims.len()])?
.map(|x| ValType::Constant(x))
.into();
let mut outer_results = vec![];
for coord in batch_cartesian_coord {
let slice: Vec<Range<usize>> = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut index_slice = index.get_slice(&slice)?;
index_slice.reshape(&index_dims[batch_dims..])?;
// expand the index to the full dims by iterating over the rest of the dims and inserting constants
// eg in the case
// batch_dims = 0
// data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2]
// indices = [[0,1],[1,0]] # indices_shape = [2, 2]
// output = [[2,3],[4,5]] # output_shape = [2, 2]
// the index should be expanded to the shape [2,2,3]: [[0,1,0],[0,1,1],[1,0,0],[1,0,1]]
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if inner_cartesian_coord.is_empty() {
inner_cartesian_coord.push(vec![]);
}
let indices = if last_dim < &input_rank {
inner_cartesian_coord
.iter()
.map(|x| {
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index = index_slice.get_slice(&slice)?;
// map over cartesian coord of rest of dims and insert constants
let grid = (*last_dim..input_rank)
.map(|x| 0..dims[x])
.multi_cartesian_product();
Ok(grid
.map(|x| {
let index = index.clone();
let constant_valtensor: ValTensor<F> = Tensor::from(
x.into_iter().map(|x| ValType::Constant(F::from(x as u64))),
)
.into();
index.concat(constant_valtensor)
})
.collect::<Result<Vec<_>, TensorError>>()?)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>()
} else {
inner_cartesian_coord
.iter()
.map(|x| {
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
Ok(index_slice.get_slice(&slice)?)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
};
let mut const_offset = F::ZERO;
for i in 0..batch_dims {
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
}
let const_offset = create_constant_tensor(const_offset, 1);
let mut results = vec![];
for index_val in indices {
let mut index_val = index_val.clone();
index_val.flatten();
let res = pairwise(
config,
region,
&[index_val.clone(), index_dim_multiplier.clone()],
BaseOp::Mult,
)?;
let res = res.concat(const_offset.clone())?;
let res = sum(config, region, &[res])?;
results.push(res.get_inner_tensor()?.clone());
// assert than res is less than the product of the dims
assert!(
res.get_int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as i128),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
}
let result_tensor = Tensor::from(results.into_iter());
outer_results.push(result_tensor.combine()?);
}
let output = Tensor::from(outer_results.into_iter());
let output = output.combine()?;
Ok(output.into())
}
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
ordered: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, fullset) = (values[0].clone(), values[1].clone());
let set_len = fullset.len();
input.flatten();
let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
let mut fullset_evals = fullset.get_int_evals()?.into_iter().collect::<Vec<_>>();
// get the difference between the two vectors
for eval in input_evals.iter() {
// delete first occurence of that value
if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
fullset_evals.remove(pos);
}
}
// if fullset + input is the same length, then input is a subset of fullset, else randomly delete elements, this is a patch for
// the fact that we can't have a tensor of unknowns when using constant during gen-settings
if fullset_evals.len() != set_len - input.len() {
fullset_evals.truncate(set_len - input.len());
}
fullset_evals
.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
let dim = fullset.len() - input.len();
Tensor::new(Some(&vec![Value::<F>::unknown(); dim]), &[dim])?.into()
};
// assign the claimed output
claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
// input and claimed output should be the shuffles of fullset
// concatentate input and claimed output
let input_and_claimed_output = input.concat(claimed_output.clone())?;
// assert that this is a permutation/shuffle
shuffles(
config,
region,
&[input_and_claimed_output.clone()],
&[fullset.clone()],
)?;
if ordered {
// assert that the claimed output is sorted
claimed_output = _sort_ascending(config, region, &[claimed_output])?;
}
Ok(claimed_output)
}
/// Gather accumulated layout
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -1105,88 +1343,158 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 3],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone());
let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());
assert_eq!(input.dims().len(), index.dims().len());
let mut assigned_len = vec![];
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
assigned_len.push(input.len());
}
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
assigned_len.push(index.len());
}
if !src.all_prev_assigned() {
src = region.assign(&config.custom_gates.output, &src)?;
assigned_len.push(src.len());
region.increment(index.len());
}
if !assigned_len.is_empty() {
// safe to unwrap since we've just checked it has at least one element
region.increment(*assigned_len.iter().max().unwrap());
}
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
// Calculate the output tensor size
let input_dim = input.dims()[dim];
let output_size = index.dims().to_vec();
let claimed_output: ValTensor<F> = if is_assigned {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
// these will be assigned as constants
let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;
// Allocate memory for the output tensor
let cartesian_coord = output_size
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut output: Tensor<()> = Tensor::new(None, &output_size)?;
let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
let coord = cartesian_coord[i].clone();
let index_val = index.get_inner_tensor()?.get(&coord);
let src_val = src.get_inner_tensor()?.get(&coord);
let src_valtensor: ValTensor<F> = Tensor::from([src_val.clone()].into_iter()).into();
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dim;
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let index_valtensor: ValTensor<F> = Tensor::from([index_val.clone()].into_iter()).into();
let mask = equals(config, region, &[index_valtensor, indices.clone()])?;
let res = iff(config, region, &[mask, src_valtensor, sliced_input])?;
let input_cartesian_coord = slice.into_iter().multi_cartesian_product();
let mutable_input_inner = input.get_inner_tensor_mut()?;
for (i, r) in res.get_inner_tensor()?.iter().enumerate() {
let coord = input_cartesian_coord
.clone()
.nth(i)
.ok_or("invalid coord")?;
*mutable_input_inner.get_mut(&coord) = r.clone();
}
Ok(())
res.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
output
.iter_mut()
.enumerate()
.map(|(i, _)| inner_loop_function(i, region))
.collect::<Result<Vec<()>, Box<dyn Error>>>()?;
// assign the claimed output
let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());
claimed_output.reshape(input.dims())?;
Ok(input)
// scatter elements is the inverse of gather elements
let (gather_src, linear_index) = gather_elements(
config,
region,
&[claimed_output.clone(), index.clone()],
dim,
)?;
// assert this is equal to the src
enforce_equality(config, region, &[gather_src, src])?;
let full_index_set: ValTensor<F> =
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
let input_indices = get_missing_set_elements(
config,
region,
&[linear_index, full_index_set.clone()],
true,
)?;
claimed_output.flatten();
let (gather_input, _) = gather_elements(
config,
region,
&[claimed_output.clone(), input_indices.clone()],
0,
)?;
// assert this is a subset of the input
dynamic_lookup(
config,
region,
&[input_indices, gather_input],
&[full_index_set, input.clone()],
)?;
claimed_output.reshape(input.dims())?;
Ok(claimed_output)
}
/// Scatter Nd
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
region.increment(index.len());
}
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
let claimed_output: ValTensor<F> = if is_assigned {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
let res = tensor::ops::scatter_nd(&input_inner, &index_inner, &src_inner)?;
res.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
// assign the claimed output
let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());
claimed_output.reshape(input.dims())?;
// scatter elements is the inverse of gather elements
let (gather_src, linear_index) =
gather_nd(config, region, &[claimed_output.clone(), index.clone()], 0)?;
// assert this is equal to the src
enforce_equality(config, region, &[gather_src, src])?;
let full_index_set: ValTensor<F> =
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
let input_indices = get_missing_set_elements(
config,
region,
&[linear_index, full_index_set.clone()],
true,
)?;
// now that it is flattened we can gather over elements on dim 0
claimed_output.flatten();
let (gather_input, _) = gather_elements(
config,
region,
&[claimed_output.clone(), input_indices.clone()],
0,
)?;
// assert this is a subset of the input
dynamic_lookup(
config,
region,
&[input_indices, gather_input],
&[full_index_set, input.clone()],
)?;
claimed_output.reshape(input.dims())?;
Ok(claimed_output)
}
/// sum accumulated layout
@@ -1443,17 +1751,11 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// these will be assigned as constants
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let argmax = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> {
argmax(config, region, values, indices.clone())
};
let argmax =
move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> { argmax(config, region, values) };
// calculate value of output
axes_wise_op(config, region, values, &[dim], argmax)
@@ -1479,18 +1781,12 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// calculate value of output
// these will be assigned as constants
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let argmin = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> {
argmin(config, region, values, indices.clone())
};
let argmin =
move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> { argmin(config, region, values) };
axes_wise_op(config, region, values, &[dim], argmin)
}
@@ -2665,7 +2961,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// this is safe because we later constrain it
let argmax = values[0]
@@ -2688,7 +2983,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
config,
region,
&[values[0].clone(), assigned_argmax.clone()],
indices,
)?;
let max_val = max(config, region, &[values[0].clone()])?;
@@ -2703,7 +2997,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// this is safe because we later constrain it
let argmin = values[0]
@@ -2727,7 +3020,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
config,
region,
&[values[0].clone(), assigned_argmin.clone()],
indices,
)?;
let min_val = min(config, region, &[values[0].clone()])?;

View File

@@ -14,10 +14,17 @@ pub enum PolyOp {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
GatherND {
batch_dims: usize,
indices: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterND {
constant_idx: Option<Tensor<usize>>,
},
MultiBroadcastTo {
shape: Vec<usize>,
},
@@ -89,7 +96,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -213,6 +222,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
@@ -229,6 +250,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
}?;
Ok(ForwardResult { output: res })
@@ -276,7 +312,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
}
}
PolyOp::GatherND {
batch_dims,
indices,
} => {
if let Some(idx) = indices {
tensor::ops::gather_nd(values[0].get_inner_tensor()?, idx, *batch_dims)?.into()
} else {
layouts::gather_nd(config, region, values[..].try_into()?, *batch_dims)?.0
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
@@ -292,6 +338,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::ScatterND { constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter_nd(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
)?
.into()
} else {
layouts::scatter_nd(config, region, values[..].try_into()?)?
}
}
PolyOp::DeConv {
padding,
output_padding,
@@ -389,7 +447,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else if matches!(self, PolyOp::ScatterElements { .. }) {
} else if matches!(self, PolyOp::ScatterElements { .. })
| matches!(self, PolyOp::ScatterND { .. })
{
vec![0, 2]
} else {
vec![]

View File

@@ -402,9 +402,6 @@ pub enum Commands {
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
logrows: Option<u32>,
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check: CheckMode,
},
/// Loads model and input and runs mock prover (for testing)
Mock {

View File

@@ -159,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
settings_path,
logrows,
check,
} => get_srs_cmd(srs_path, settings_path, logrows, check).await,
} => get_srs_cmd(srs_path, settings_path, logrows).await,
Commands::Table { model, args } => table(model, args),
#[cfg(feature = "render")]
Commands::RenderCircuit {
@@ -492,23 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
use std::io::Read;
let path = get_srs_path(logrows, srs_path);
let file = std::fs::File::open(path.clone())?;
let file = std::fs::File::open(path)?;
let mut reader = std::io::BufReader::new(file);
let mut buffer = vec![];
let mut reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let bytes_read = reader.read_to_end(&mut buffer)?;
info!(
"read {} bytes from SRS file (vector of len = {})",
"read {} bytes from file (vector of len = {})",
bytes_read,
buffer.len()
);
let hash = sha256::digest(buffer);
info!("SRS hash: {}", hash);
info!("file hash: {}", hash);
Ok(hash)
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
let path = get_srs_path(logrows, srs_path);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
Some(h) => h,
@@ -532,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
srs_path: Option<PathBuf>,
settings_path: Option<PathBuf>,
logrows: Option<u32>,
check_mode: CheckMode,
) -> Result<String, Box<dyn Error>> {
// logrows overrides settings
@@ -560,21 +563,20 @@ pub(crate) async fn get_srs_cmd(
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
// check the SRS
if matches!(check_mode, CheckMode::SAFE) {
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated");
}
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated.");
info!("Saving SRS to disk...");
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
buffer.write_all(reader.get_ref())?;
buffer.flush()?;
params.write(&mut buffer)?;
info!("Saved SRS to disk.");
info!("SRS downloaded");
} else {
@@ -969,8 +971,8 @@ pub(crate) fn calibrate(
continue;
}
}
// drop the gag
// drop the gag
#[cfg(unix)]
drop(_r);
#[cfg(unix)]
@@ -1695,7 +1697,7 @@ pub(crate) fn fuzz(
let logrows = circuit.settings().run_args.logrows;
info!("setting up tests");
#[cfg(unix)]
let _r = Gag::stdout()?;
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
@@ -1713,6 +1715,7 @@ pub(crate) fn fuzz(
let public_inputs = circuit.prepare_public_inputs(&data)?;
let strategy = KZGSingleStrategy::new(&params);
#[cfg(unix)]
std::mem::drop(_r);
info!("starting fuzzing");
@@ -1903,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
passed: &AtomicBool,
) {
let num_failures = AtomicI64::new(0);
#[cfg(unix)]
let _r = Gag::stdout().unwrap();
let pb = init_bar(num_runs as u64);
@@ -1916,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
pb.inc(1);
});
pb.finish_with_message("Done.");
#[cfg(unix)]
std::mem::drop(_r);
info!(
"num failures: {} out of {}",

View File

@@ -23,7 +23,10 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::ops::{
array::{Gather, GatherElements, MultiBroadcastTo, OneHot, ScatterElements, Slice, Topk},
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
},
change_axes::AxisOp,
cnn::{Conv, Deconv},
einsum::EinSum,
@@ -467,6 +470,78 @@ pub fn new_op_from_onnx(
// Extract the max value
}
"ScatterNd" => {
if inputs.len() != 3 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"scatter nd".to_string(),
)));
};
// just verify it deserializes correctly
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherNd" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather nd".to_string(),
)));
};
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
let batch_dims = op.batch_dims;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherElements" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(

View File

@@ -484,7 +484,6 @@ fn get_srs(
srs_path,
settings_path,
logrows,
CheckMode::SAFE,
))
.map_err(|e| {
let err_str = format!("Failed to get srs: {}", e);
@@ -1104,6 +1103,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;

View File

@@ -673,6 +673,68 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&res), &dims)
}
/// Set a slice of the Tensor.
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
/// assert_eq!(a, b);
/// ```
pub fn set_slice(
&mut self,
indices: &[Range<usize>],
value: &Tensor<T>,
) -> Result<(), TensorError>
where
T: Send + Sync,
{
if indices.is_empty() {
return Ok(());
}
if self.dims.len() < indices.len() {
return Err(TensorError::DimError(format!(
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
indices, self.dims
)));
}
// if indices weren't specified we fill them in as required
let mut full_indices = indices.to_vec();
let omitted_dims = (indices.len()..self.dims.len())
.map(|i| self.dims[i])
.collect::<Vec<_>>();
for dim in &omitted_dims {
full_indices.push(0..*dim);
}
let full_dims = full_indices
.iter()
.map(|x| x.end - x.start)
.collect::<Vec<_>>();
// now broadcast the value to the full dims
let value = value.expand(&full_dims)?;
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
let _ = cartesian_coord
.iter()
.enumerate()
.map(|(i, e)| {
self.set(e, value[i].clone());
})
.collect::<Vec<_>>();
Ok(())
}
/// Get the array index from rows / columns indices.
///
/// ```

View File

@@ -2,7 +2,7 @@ use super::TensorError;
use crate::tensor::{Tensor, TensorType};
use itertools::Itertools;
use maybe_rayon::{
iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator,
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
prelude::IntoParallelRefIterator,
};
use std::collections::{HashMap, HashSet};
@@ -1328,6 +1328,316 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
Ok(output)
}
/// Gather ND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to gather
/// * `batch_dims` - Number of batch dimensions
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::gather_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3]),
/// &[2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 0, 1, 1]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 3]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3, 4, 5, 6, 7]),
/// &[2, 2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 1, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]),
/// &[2, 2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0]),
/// &[2, 2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
if last_value > &(input_dims.len() - batch_dims) {
return Err(TensorError::DimMismatch("gather_nd".to_string()));
}
let output_size =
// If indices_shape[-1] == r-b, since the rank of indices is q,
// indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
// where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
// Let us think of each such r-b ranked tensor as indices_slice.
// Each scalar value corresponding to data[0:b-1,indices_slice] is filled into
// the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
// if indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b.
// Let us think of each such tensors as indices_slice.
// Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
{
let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
let mut dims = index_dims[..index_dims.len() - 1].to_vec();
let input_offset = batch_dims + last_value;
dims.extend(input_dims[input_offset..input_dims.len()].to_vec());
assert_eq!(output_rank, dims.len());
dims
};
// cartesian coord over batch dims
let mut batch_cartesian_coord = input_dims[0..batch_dims]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if batch_cartesian_coord.is_empty() {
batch_cartesian_coord.push(vec![]);
}
let outputs = batch_cartesian_coord
.par_iter()
.map(|batch_coord| {
let batch_slice = batch_coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut index_slice = index.get_slice(&batch_slice)?;
index_slice.reshape(&index.dims()[batch_dims..])?;
let mut input_slice = input.get_slice(&batch_slice)?;
input_slice.reshape(&input.dims()[batch_dims..])?;
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if inner_cartesian_coord.is_empty() {
inner_cartesian_coord.push(vec![]);
}
let output = inner_cartesian_coord
.iter()
.map(|coord| {
let slice = coord
.iter()
.map(|x| *x..*x + 1)
.chain(batch_coord.iter().map(|x| *x..*x + 1))
.collect::<Vec<_>>();
let index_slice = index_slice
.get_slice(&slice)
.unwrap()
.iter()
.map(|x| *x..*x + 1)
.collect::<Vec<_>>();
input_slice.get_slice(&index_slice).unwrap()
})
.collect::<Tensor<_>>();
output.combine()
})
.collect::<Result<Vec<_>, _>>()?;
let mut outputs = outputs.into_iter().flatten().collect::<Tensor<_>>();
outputs.reshape(&output_size)?;
Ok(outputs)
}
/// Scatter ND.
/// This operator is the inverse of GatherND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to scatter
/// * `src` - Tensor of src
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scatter_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[8],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[4, 3, 1, 7]),
/// &[4, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10, 11, 12]),
/// &[4],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 2]),
/// &[2, 1],
/// ).unwrap();
///
/// let src = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// ]),
/// &[2, 4, 4],
/// ).unwrap();
///
/// let result = scatter_nd(&x, &index, &src).unwrap();
///
/// let expected = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[2, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10]),
/// &[2],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[1, 1, 2],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9]),
/// &[1, 1],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ````
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("scatter_nd".to_string()))?;
if last_value > &input_dims.len() {
return Err(TensorError::DimMismatch("scatter_nd".to_string()));
}
let mut output = input.clone();
let cartesian_coord = index_dims[0..index_dims.len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
cartesian_coord
.iter()
.map(|coord| {
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index_val = index.get_slice(&slice)?;
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let src_val = src.get_slice(&slice)?;
output.set_slice(&index_slice, &src_val)?;
Ok(())
})
.collect::<Result<Vec<_>, _>>()?;
Ok(output)
}
fn axes_op<T: TensorType + Send + Sync>(
a: &Tensor<T>,
axes: &[usize],

View File

@@ -494,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
Ok(integer_evals.into_iter().into())
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
_ => {}
};
Ok(tensor)
}
/// Calls `pad_to_zero_rem` on the inner tensor.

View File

@@ -193,7 +193,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 77] = [
const TESTS: [&str; 79] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -275,6 +275,8 @@ mod native_tests {
"ltsf",
"remainder", //75
"bitshift",
"gather_nd",
"scatter_nd",
];
const WASM_TESTS: [&str; 46] = [
@@ -502,7 +504,7 @@ mod native_tests {
}
});
seq!(N in 0..=76 {
seq!(N in 0..=78 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -589,13 +591,16 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
}
}
#(#[test_case(TESTS[N])])*

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.