|
|
|
|
@@ -646,14 +646,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
dim_indices: ValTensor<F>,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
let (mut input, index) = (values[0].clone(), values[1].clone());
|
|
|
|
|
input.flatten();
|
|
|
|
|
|
|
|
|
|
if !(dim_indices.all_prev_assigned() || region.is_dummy()) {
|
|
|
|
|
return Err("dim_indices must be assigned".into());
|
|
|
|
|
}
|
|
|
|
|
// these will be assigned as constants
|
|
|
|
|
let dim_indices: ValTensor<F> =
|
|
|
|
|
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
|
|
|
|
|
|
|
|
|
|
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;
|
|
|
|
|
|
|
|
|
|
@@ -929,91 +928,25 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
dim: usize,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
let (mut input, mut index_clone) = (values[0].clone(), values[1].clone());
|
|
|
|
|
let (input, mut index_clone) = (values[0].clone(), values[1].clone());
|
|
|
|
|
index_clone.flatten();
|
|
|
|
|
if index_clone.is_singleton() {
|
|
|
|
|
index_clone.reshape(&[1])?;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let mut assigned_len = vec![];
|
|
|
|
|
if !input.all_prev_assigned() {
|
|
|
|
|
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
|
|
|
|
assigned_len.push(input.len());
|
|
|
|
|
}
|
|
|
|
|
if !index_clone.all_prev_assigned() {
|
|
|
|
|
index_clone = region.assign(&config.custom_gates.inputs[1], &index_clone)?;
|
|
|
|
|
assigned_len.push(index_clone.len());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !assigned_len.is_empty() {
|
|
|
|
|
// safe to unwrap since we've just checked it has at least one element
|
|
|
|
|
region.increment(*assigned_len.iter().max().unwrap());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Calculate the output tensor size
|
|
|
|
|
let input_dims = input.dims();
|
|
|
|
|
let mut output_size = input_dims.to_vec();
|
|
|
|
|
|
|
|
|
|
output_size[dim] = index_clone.dims()[0];
|
|
|
|
|
|
|
|
|
|
// these will be assigned as constants
|
|
|
|
|
let mut indices = Tensor::from((0..input.dims()[dim] as u64).map(|x| F::from(x)));
|
|
|
|
|
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
|
|
|
|
region.increment(indices.len());
|
|
|
|
|
let linear_index =
|
|
|
|
|
linearize_element_index(config, region, &[index_clone], input_dims, dim, true)?;
|
|
|
|
|
|
|
|
|
|
let mut iteration_dims = output_size.clone();
|
|
|
|
|
iteration_dims[dim] = 1;
|
|
|
|
|
let mut output = select(config, region, &[input, linear_index])?;
|
|
|
|
|
|
|
|
|
|
// Allocate memory for the output tensor
|
|
|
|
|
let cartesian_coord = iteration_dims
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|x| 0..*x)
|
|
|
|
|
.multi_cartesian_product()
|
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
let mut results = HashMap::new();
|
|
|
|
|
|
|
|
|
|
for coord in cartesian_coord {
|
|
|
|
|
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
|
|
|
|
slice[dim] = 0..input_dims[dim];
|
|
|
|
|
|
|
|
|
|
let mut sliced_input = input.get_slice(&slice)?;
|
|
|
|
|
sliced_input.flatten();
|
|
|
|
|
|
|
|
|
|
let res = select(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[sliced_input, index_clone.clone()],
|
|
|
|
|
indices.clone(),
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
results.insert(coord, res);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Allocate memory for the output tensor
|
|
|
|
|
let cartesian_coord = output_size
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|x| 0..*x)
|
|
|
|
|
.multi_cartesian_product()
|
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
let mut output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
|
|
|
|
|
let coord = cartesian_coord[i].clone();
|
|
|
|
|
let mut key = coord.clone();
|
|
|
|
|
key[dim] = 0;
|
|
|
|
|
let result = &results.get(&key).ok_or("missing result")?;
|
|
|
|
|
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
|
|
|
|
|
Ok::<ValType<F>, region::RegionError>(o)
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
// Reshape the output tensor
|
|
|
|
|
if index_clone.is_singleton() {
|
|
|
|
|
output_size.remove(dim);
|
|
|
|
|
}
|
|
|
|
|
output.reshape(&output_size)?;
|
|
|
|
|
|
|
|
|
|
Ok(output.into())
|
|
|
|
|
Ok(output)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Gather accumulated layout
|
|
|
|
|
@@ -1022,82 +955,173 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
dim: usize,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
let (mut input, mut index) = (values[0].clone(), values[1].clone());
|
|
|
|
|
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
|
|
|
|
|
let (input, index) = (values[0].clone(), values[1].clone());
|
|
|
|
|
|
|
|
|
|
assert_eq!(input.dims().len(), index.dims().len());
|
|
|
|
|
|
|
|
|
|
if !input.all_prev_assigned() {
|
|
|
|
|
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
|
|
|
|
}
|
|
|
|
|
if !index.all_prev_assigned() {
|
|
|
|
|
index = region.assign(&config.custom_gates.inputs[1], &index)?;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
region.increment(std::cmp::max(input.len(), index.len()));
|
|
|
|
|
|
|
|
|
|
// Calculate the output tensor size
|
|
|
|
|
let input_dims = input.dims();
|
|
|
|
|
let output_size = index.dims().to_vec();
|
|
|
|
|
|
|
|
|
|
// these will be assigned as constants
|
|
|
|
|
let mut indices = Tensor::from((0..input_dims[dim] as u64).map(|x| F::from(x)));
|
|
|
|
|
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
|
|
|
|
region.increment(indices.len());
|
|
|
|
|
let linear_index = linearize_element_index(config, region, &[index], input.dims(), dim, false)?;
|
|
|
|
|
|
|
|
|
|
let mut iteration_dims = output_size.clone();
|
|
|
|
|
iteration_dims[dim] = 1;
|
|
|
|
|
let mut output = select(config, region, &[input, linear_index.clone()])?;
|
|
|
|
|
|
|
|
|
|
output.reshape(&output_size)?;
|
|
|
|
|
|
|
|
|
|
Ok((output, linear_index))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
|
|
|
|
|
/// The linearized index is the index of the element in the flattened tensor.
|
|
|
|
|
/// FOr instance if the dims is [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
|
|
|
|
|
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
dims: &[usize],
|
|
|
|
|
dim: usize,
|
|
|
|
|
is_flat_index: bool,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
let index = values[0].clone();
|
|
|
|
|
if !is_flat_index {
|
|
|
|
|
assert_eq!(index.dims().len(), dims.len());
|
|
|
|
|
// if the index is already flat, return it
|
|
|
|
|
if index.dims().len() == 1 {
|
|
|
|
|
return Ok(index);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;
|
|
|
|
|
|
|
|
|
|
let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
|
|
|
|
|
let mut res = 1;
|
|
|
|
|
for dim in dims.iter().skip(i + 1) {
|
|
|
|
|
res *= dim;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok::<_, region::RegionError>(F::from(res as u64))
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
let iteration_dims = if is_flat_index {
|
|
|
|
|
let mut dims = dims.to_vec();
|
|
|
|
|
dims[dim] = index.len();
|
|
|
|
|
dims
|
|
|
|
|
} else {
|
|
|
|
|
index.dims().to_vec()
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Allocate memory for the output tensor
|
|
|
|
|
let cartesian_coord = iteration_dims
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|x| 0..*x)
|
|
|
|
|
.multi_cartesian_product()
|
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
let mut results = HashMap::new();
|
|
|
|
|
let val_dim_multiplier: ValTensor<F> = dim_multiplier
|
|
|
|
|
.get_slice(&[dim..dim + 1])?
|
|
|
|
|
.map(|x| ValType::Constant(x))
|
|
|
|
|
.into();
|
|
|
|
|
|
|
|
|
|
for coord in cartesian_coord {
|
|
|
|
|
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
|
|
|
|
slice[dim] = 0..input_dims[dim];
|
|
|
|
|
let mut output = Tensor::new(None, &[cartesian_coord.len()])?;
|
|
|
|
|
|
|
|
|
|
let mut sliced_input = input.get_slice(&slice)?;
|
|
|
|
|
sliced_input.flatten();
|
|
|
|
|
let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
|
|
|
|
|
let coord = cartesian_coord[i].clone();
|
|
|
|
|
let slice: Vec<Range<usize>> = if is_flat_index {
|
|
|
|
|
coord[dim..dim + 1].iter().map(|x| *x..*x + 1).collect()
|
|
|
|
|
} else {
|
|
|
|
|
coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>()
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
slice[dim] = 0..output_size[dim];
|
|
|
|
|
let mut sliced_index = index.get_slice(&slice)?;
|
|
|
|
|
sliced_index.flatten();
|
|
|
|
|
let index_val = index.get_slice(&slice)?;
|
|
|
|
|
|
|
|
|
|
let res = select(
|
|
|
|
|
let mut const_offset = F::ZERO;
|
|
|
|
|
for i in 0..dims.len() {
|
|
|
|
|
if i != dim {
|
|
|
|
|
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let const_offset = create_constant_tensor(const_offset, 1);
|
|
|
|
|
|
|
|
|
|
let res = pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[sliced_input, sliced_index],
|
|
|
|
|
indices.clone(),
|
|
|
|
|
&[index_val, val_dim_multiplier.clone()],
|
|
|
|
|
BaseOp::Mult,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
results.insert(coord, res);
|
|
|
|
|
}
|
|
|
|
|
let res = pairwise(config, region, &[res, const_offset], BaseOp::Add)?;
|
|
|
|
|
|
|
|
|
|
// Allocate memory for the output tensor
|
|
|
|
|
let cartesian_coord = output_size
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|x| 0..*x)
|
|
|
|
|
.multi_cartesian_product()
|
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
Ok(res.get_inner_tensor()?[0].clone())
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
|
|
|
|
|
let coord = cartesian_coord[i].clone();
|
|
|
|
|
let mut key = coord.clone();
|
|
|
|
|
key[dim] = 0;
|
|
|
|
|
let result = &results.get(&key).ok_or("missing result")?;
|
|
|
|
|
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
|
|
|
|
|
Ok::<ValType<F>, region::RegionError>(o)
|
|
|
|
|
})?;
|
|
|
|
|
region.apply_in_loop(&mut output, inner_loop_function)?;
|
|
|
|
|
|
|
|
|
|
Ok(output.into())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
ordered: bool,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
let (mut input, fullset) = (values[0].clone(), values[1].clone());
|
|
|
|
|
let set_len = fullset.len();
|
|
|
|
|
input.flatten();
|
|
|
|
|
|
|
|
|
|
let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;
|
|
|
|
|
|
|
|
|
|
let mut claimed_output: ValTensor<F> = if is_assigned {
|
|
|
|
|
let input_evals = input.get_int_evals()?;
|
|
|
|
|
let mut fullset_evals = fullset.get_int_evals()?.into_iter().collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
// get the difference between the two vectors
|
|
|
|
|
for eval in input_evals.iter() {
|
|
|
|
|
// delete first occurence of that value
|
|
|
|
|
if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
|
|
|
|
|
fullset_evals.remove(pos);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if fullset + input is the same length, then input is a subset of fullset, else randomly delete elements, this is a patch for
|
|
|
|
|
// the fact that we can't have a tensor of unknowns when using constant during gen-settings
|
|
|
|
|
if fullset_evals.len() != set_len - input.len() {
|
|
|
|
|
fullset_evals.truncate(set_len - input.len());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fullset_evals
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|x| Value::known(i128_to_felt(*x)))
|
|
|
|
|
.collect::<Tensor<Value<F>>>()
|
|
|
|
|
.into()
|
|
|
|
|
} else {
|
|
|
|
|
let dim = fullset.len() - input.len();
|
|
|
|
|
Tensor::new(Some(&vec![Value::<F>::unknown(); dim]), &[dim])?.into()
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// assign the claimed output
|
|
|
|
|
claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
|
|
|
|
|
|
|
|
|
// input and claimed output should be the shuffles of fullset
|
|
|
|
|
// concatentate input and claimed output
|
|
|
|
|
let input_and_claimed_output = input.concat(claimed_output.clone())?;
|
|
|
|
|
|
|
|
|
|
// assert that this is a permutation/shuffle
|
|
|
|
|
shuffles(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[input_and_claimed_output.clone()],
|
|
|
|
|
&[fullset.clone()],
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
if ordered {
|
|
|
|
|
// assert that the claimed output is sorted
|
|
|
|
|
claimed_output = _sort_ascending(config, region, &[claimed_output])?;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(claimed_output)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Gather accumulated layout
|
|
|
|
|
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
@@ -1105,88 +1129,81 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
values: &[ValTensor<F>; 3],
|
|
|
|
|
dim: usize,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone());
|
|
|
|
|
let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());
|
|
|
|
|
|
|
|
|
|
assert_eq!(input.dims().len(), index.dims().len());
|
|
|
|
|
|
|
|
|
|
let mut assigned_len = vec![];
|
|
|
|
|
let input_dims = input.dims();
|
|
|
|
|
|
|
|
|
|
if !input.all_prev_assigned() {
|
|
|
|
|
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
|
|
|
|
assigned_len.push(input.len());
|
|
|
|
|
}
|
|
|
|
|
if !index.all_prev_assigned() {
|
|
|
|
|
index = region.assign(&config.custom_gates.inputs[1], &index)?;
|
|
|
|
|
assigned_len.push(index.len());
|
|
|
|
|
}
|
|
|
|
|
if !src.all_prev_assigned() {
|
|
|
|
|
src = region.assign(&config.custom_gates.output, &src)?;
|
|
|
|
|
assigned_len.push(src.len());
|
|
|
|
|
region.increment(index.len());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !assigned_len.is_empty() {
|
|
|
|
|
// safe to unwrap since we've just checked it has at least one element
|
|
|
|
|
region.increment(*assigned_len.iter().max().unwrap());
|
|
|
|
|
}
|
|
|
|
|
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
|
|
|
|
|
|
|
|
|
|
// Calculate the output tensor size
|
|
|
|
|
let input_dim = input.dims()[dim];
|
|
|
|
|
let output_size = index.dims().to_vec();
|
|
|
|
|
let claimed_output: ValTensor<F> = if is_assigned {
|
|
|
|
|
let input_inner = input.get_int_evals()?;
|
|
|
|
|
let index_inner = index.get_int_evals()?.map(|x| x as usize);
|
|
|
|
|
let src_inner = src.get_int_evals()?;
|
|
|
|
|
|
|
|
|
|
// these will be assigned as constants
|
|
|
|
|
let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x)));
|
|
|
|
|
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
|
|
|
|
region.increment(indices.len());
|
|
|
|
|
let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;
|
|
|
|
|
|
|
|
|
|
// Allocate memory for the output tensor
|
|
|
|
|
let cartesian_coord = output_size
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|x| 0..*x)
|
|
|
|
|
.multi_cartesian_product()
|
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
let mut output: Tensor<()> = Tensor::new(None, &output_size)?;
|
|
|
|
|
|
|
|
|
|
let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
|
|
|
|
|
let coord = cartesian_coord[i].clone();
|
|
|
|
|
let index_val = index.get_inner_tensor()?.get(&coord);
|
|
|
|
|
|
|
|
|
|
let src_val = src.get_inner_tensor()?.get(&coord);
|
|
|
|
|
let src_valtensor: ValTensor<F> = Tensor::from([src_val.clone()].into_iter()).into();
|
|
|
|
|
|
|
|
|
|
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
|
|
|
|
slice[dim] = 0..input_dim;
|
|
|
|
|
|
|
|
|
|
let mut sliced_input = input.get_slice(&slice)?;
|
|
|
|
|
sliced_input.flatten();
|
|
|
|
|
|
|
|
|
|
let index_valtensor: ValTensor<F> = Tensor::from([index_val.clone()].into_iter()).into();
|
|
|
|
|
|
|
|
|
|
let mask = equals(config, region, &[index_valtensor, indices.clone()])?;
|
|
|
|
|
|
|
|
|
|
let res = iff(config, region, &[mask, src_valtensor, sliced_input])?;
|
|
|
|
|
|
|
|
|
|
let input_cartesian_coord = slice.into_iter().multi_cartesian_product();
|
|
|
|
|
let mutable_input_inner = input.get_inner_tensor_mut()?;
|
|
|
|
|
|
|
|
|
|
for (i, r) in res.get_inner_tensor()?.iter().enumerate() {
|
|
|
|
|
let coord = input_cartesian_coord
|
|
|
|
|
.clone()
|
|
|
|
|
.nth(i)
|
|
|
|
|
.ok_or("invalid coord")?;
|
|
|
|
|
*mutable_input_inner.get_mut(&coord) = r.clone();
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
|
res.iter()
|
|
|
|
|
.map(|x| Value::known(i128_to_felt(*x)))
|
|
|
|
|
.collect::<Tensor<Value<F>>>()
|
|
|
|
|
.into()
|
|
|
|
|
} else {
|
|
|
|
|
Tensor::new(
|
|
|
|
|
Some(&vec![Value::<F>::unknown(); input.len()]),
|
|
|
|
|
&[input.len()],
|
|
|
|
|
)?
|
|
|
|
|
.into()
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
output
|
|
|
|
|
.iter_mut()
|
|
|
|
|
.enumerate()
|
|
|
|
|
.map(|(i, _)| inner_loop_function(i, region))
|
|
|
|
|
.collect::<Result<Vec<()>, Box<dyn Error>>>()?;
|
|
|
|
|
// assign the claimed output
|
|
|
|
|
let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
|
|
|
|
region.increment(claimed_output.len());
|
|
|
|
|
claimed_output.reshape(input.dims())?;
|
|
|
|
|
|
|
|
|
|
Ok(input)
|
|
|
|
|
// scatter elements is the inverse of gather elements
|
|
|
|
|
let (gather_src, linear_index) = gather_elements(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[claimed_output.clone(), index.clone()],
|
|
|
|
|
dim,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
// assert this is equal to the src
|
|
|
|
|
enforce_equality(config, region, &[gather_src, src])?;
|
|
|
|
|
|
|
|
|
|
let full_index_set: ValTensor<F> =
|
|
|
|
|
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
|
|
|
|
|
let input_indices = get_missing_set_elements(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[linear_index, full_index_set.clone()],
|
|
|
|
|
true,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
claimed_output.flatten();
|
|
|
|
|
let (gather_input, _) = gather_elements(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[claimed_output.clone(), input_indices.clone()],
|
|
|
|
|
0,
|
|
|
|
|
)?;
|
|
|
|
|
// assert this is a subset of the input
|
|
|
|
|
dynamic_lookup(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[input_indices, gather_input],
|
|
|
|
|
&[full_index_set, input.clone()],
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
claimed_output.reshape(input_dims)?;
|
|
|
|
|
|
|
|
|
|
Ok(claimed_output)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// sum accumulated layout
|
|
|
|
|
@@ -1443,17 +1460,11 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
dim: usize,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
// these will be assigned as constants
|
|
|
|
|
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
|
|
|
|
|
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
|
|
|
|
region.increment(indices.len());
|
|
|
|
|
|
|
|
|
|
let argmax = move |config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1]|
|
|
|
|
|
-> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
argmax(config, region, values, indices.clone())
|
|
|
|
|
};
|
|
|
|
|
let argmax =
|
|
|
|
|
move |config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1]|
|
|
|
|
|
-> Result<ValTensor<F>, Box<dyn Error>> { argmax(config, region, values) };
|
|
|
|
|
|
|
|
|
|
// calculate value of output
|
|
|
|
|
axes_wise_op(config, region, values, &[dim], argmax)
|
|
|
|
|
@@ -1479,18 +1490,12 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
dim: usize,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
// calculate value of output
|
|
|
|
|
// these will be assigned as constants
|
|
|
|
|
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
|
|
|
|
|
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
|
|
|
|
region.increment(indices.len());
|
|
|
|
|
|
|
|
|
|
let argmin = move |config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1]|
|
|
|
|
|
-> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
argmin(config, region, values, indices.clone())
|
|
|
|
|
};
|
|
|
|
|
let argmin =
|
|
|
|
|
move |config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1]|
|
|
|
|
|
-> Result<ValTensor<F>, Box<dyn Error>> { argmin(config, region, values) };
|
|
|
|
|
|
|
|
|
|
axes_wise_op(config, region, values, &[dim], argmin)
|
|
|
|
|
}
|
|
|
|
|
@@ -2665,7 +2670,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
indices: ValTensor<F>,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
// this is safe because we later constrain it
|
|
|
|
|
let argmax = values[0]
|
|
|
|
|
@@ -2688,7 +2692,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[values[0].clone(), assigned_argmax.clone()],
|
|
|
|
|
indices,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
let max_val = max(config, region, &[values[0].clone()])?;
|
|
|
|
|
@@ -2703,7 +2706,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
indices: ValTensor<F>,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
// this is safe because we later constrain it
|
|
|
|
|
let argmin = values[0]
|
|
|
|
|
@@ -2727,7 +2729,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[values[0].clone(), assigned_argmin.clone()],
|
|
|
|
|
indices,
|
|
|
|
|
)?;
|
|
|
|
|
let min_val = min(config, region, &[values[0].clone()])?;
|
|
|
|
|
|
|
|
|
|
|