|
|
|
|
@@ -13,6 +13,8 @@ use maybe_rayon::{
|
|
|
|
|
slice::ParallelSliceMut,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
use self::tensor::{create_constant_tensor, create_zero_tensor};
|
|
|
|
|
|
|
|
|
|
use super::{
|
|
|
|
|
chip::{BaseConfig, CircuitError},
|
|
|
|
|
region::RegionCtx,
|
|
|
|
|
@@ -21,7 +23,7 @@ use crate::{
|
|
|
|
|
circuit::{ops::base::BaseOp, utils},
|
|
|
|
|
fieldutils::{felt_to_i128, i128_to_felt},
|
|
|
|
|
tensor::{
|
|
|
|
|
get_broadcasted_shape,
|
|
|
|
|
create_unit_tensor, get_broadcasted_shape,
|
|
|
|
|
ops::{accumulated, add, mult, sub},
|
|
|
|
|
Tensor, TensorError, ValType,
|
|
|
|
|
},
|
|
|
|
|
@@ -30,29 +32,8 @@ use crate::{
|
|
|
|
|
use super::*;
|
|
|
|
|
use crate::circuit::ops::lookup::LookupOp;
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|
pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usize) -> usize {
|
|
|
|
|
let mut idx = starting_idx;
|
|
|
|
|
// let x = idx / column_len;
|
|
|
|
|
let y = idx % column_len;
|
|
|
|
|
if y + total_len < column_len {
|
|
|
|
|
return total_len;
|
|
|
|
|
}
|
|
|
|
|
// fill up first column
|
|
|
|
|
idx += column_len - y;
|
|
|
|
|
total_len += 1;
|
|
|
|
|
loop {
|
|
|
|
|
if idx >= starting_idx + total_len {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
idx += column_len;
|
|
|
|
|
total_len += 1;
|
|
|
|
|
}
|
|
|
|
|
total_len
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Same as div but splits the division into N parts
|
|
|
|
|
pub fn loop_div<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
value: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -87,7 +68,7 @@ pub fn loop_div<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Div accumulated layout
|
|
|
|
|
pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
value: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -102,9 +83,9 @@ pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
|
|
|
|
|
let range_check_bracket = felt_to_i128(div) / 2;
|
|
|
|
|
|
|
|
|
|
let mut divisor = Tensor::from(vec![ValType::Constant(div)].into_iter());
|
|
|
|
|
divisor.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let divisor = region.assign(&config.custom_gates.inputs[1], &divisor.into())?;
|
|
|
|
|
let divisor = create_constant_tensor(div, 1);
|
|
|
|
|
|
|
|
|
|
let divisor = region.assign(&config.custom_gates.inputs[1], &divisor)?;
|
|
|
|
|
region.increment(divisor.len());
|
|
|
|
|
|
|
|
|
|
let is_assigned = !input.any_unknowns()? && !divisor.any_unknowns()?;
|
|
|
|
|
@@ -161,13 +142,10 @@ fn recip_int<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
// get values where input is 0
|
|
|
|
|
let zero_mask = equals_zero(config, region, input)?;
|
|
|
|
|
|
|
|
|
|
let one_minus_zero_mask = pairwise(
|
|
|
|
|
let zero_mask_minus_one = pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[
|
|
|
|
|
zero_mask.clone(),
|
|
|
|
|
ValTensor::from(Tensor::from([ValType::Constant(F::ONE)].into_iter())),
|
|
|
|
|
],
|
|
|
|
|
&[zero_mask.clone(), create_unit_tensor(1)],
|
|
|
|
|
BaseOp::Sub,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
@@ -176,9 +154,7 @@ fn recip_int<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
region,
|
|
|
|
|
&[
|
|
|
|
|
zero_mask,
|
|
|
|
|
ValTensor::from(Tensor::from(
|
|
|
|
|
[ValType::Constant(i128_to_felt(zero_inverse_val))].into_iter(),
|
|
|
|
|
)),
|
|
|
|
|
create_constant_tensor(i128_to_felt(zero_inverse_val), 1),
|
|
|
|
|
],
|
|
|
|
|
BaseOp::Mult,
|
|
|
|
|
)?;
|
|
|
|
|
@@ -186,13 +162,13 @@ fn recip_int<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[one_minus_zero_mask, zero_inverse_val],
|
|
|
|
|
&[zero_mask_minus_one, zero_inverse_val],
|
|
|
|
|
BaseOp::Add,
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// recip accumulated layout
|
|
|
|
|
pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
value: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -254,15 +230,14 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
|
|
|
|
|
let zero_inverse_val =
|
|
|
|
|
tensor::ops::nonlinearities::zero_recip(felt_to_i128(output_scale) as f64)[0];
|
|
|
|
|
let zero_inverse =
|
|
|
|
|
Tensor::from([ValType::Constant(i128_to_felt::<F>(zero_inverse_val))].into_iter());
|
|
|
|
|
let zero_inverse = create_constant_tensor(i128_to_felt(zero_inverse_val), 1);
|
|
|
|
|
|
|
|
|
|
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
|
|
|
|
|
|
|
|
|
|
let equal_inverse_mask = equals(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[claimed_output.clone(), zero_inverse.into()],
|
|
|
|
|
&[claimed_output.clone(), zero_inverse],
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
// assert the two masks are equal
|
|
|
|
|
@@ -272,12 +247,12 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
&[equal_zero_mask.clone(), equal_inverse_mask],
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
let unit_scale = Tensor::from([ValType::Constant(i128_to_felt(range_check_len))].into_iter());
|
|
|
|
|
let unit_scale = create_constant_tensor(i128_to_felt(range_check_len), 1);
|
|
|
|
|
|
|
|
|
|
let unit_mask = pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[equal_zero_mask, unit_scale.into()],
|
|
|
|
|
&[equal_zero_mask, unit_scale],
|
|
|
|
|
BaseOp::Mult,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
@@ -296,7 +271,7 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Dot product accumulated layout
|
|
|
|
|
pub fn dot<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -327,7 +302,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
|
|
|
|
|
// if empty return a const
|
|
|
|
|
if values[0].is_empty() && values[1].is_empty() {
|
|
|
|
|
return Ok(Tensor::from([ValType::Constant(F::ZERO)].into_iter()).into());
|
|
|
|
|
return Ok(create_zero_tensor(1));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let start = instant::Instant::now();
|
|
|
|
|
@@ -407,7 +382,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Einsum
|
|
|
|
|
pub fn einsum<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
inputs: &[ValTensor<F>],
|
|
|
|
|
@@ -659,32 +634,16 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
|
|
|
|
|
let assigned_sort = region.assign(&config.custom_gates.inputs[0], &sorted.into())?;
|
|
|
|
|
|
|
|
|
|
let mut unit = Tensor::from(vec![F::from(1)].into_iter());
|
|
|
|
|
unit.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let unit = region.assign(&config.custom_gates.inputs[1], &unit.try_into()?)?;
|
|
|
|
|
|
|
|
|
|
region.increment(assigned_sort.len());
|
|
|
|
|
|
|
|
|
|
for i in 0..assigned_sort.len() - 1 {
|
|
|
|
|
// assert that each thing in turn is larger than the next
|
|
|
|
|
let window_a = assigned_sort.get_slice(&[i..i + 1])?;
|
|
|
|
|
let window_b = assigned_sort.get_slice(&[i + 1..i + 2])?;
|
|
|
|
|
let window_a = assigned_sort.get_slice(&[0..assigned_sort.len() - 1])?;
|
|
|
|
|
let window_b = assigned_sort.get_slice(&[1..assigned_sort.len()])?;
|
|
|
|
|
|
|
|
|
|
let diff = pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[window_b.clone(), window_a.clone()],
|
|
|
|
|
BaseOp::Sub,
|
|
|
|
|
)?;
|
|
|
|
|
let greater_than = nonlinearity(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[diff],
|
|
|
|
|
&LookupOp::GreaterThanEqual { a: 0.0.into() },
|
|
|
|
|
)?;
|
|
|
|
|
let is_greater = greater_equal(config, region, &[window_b.clone(), window_a.clone()])?;
|
|
|
|
|
|
|
|
|
|
enforce_equality(config, region, &[unit.clone(), greater_than.clone()])?;
|
|
|
|
|
}
|
|
|
|
|
let unit = create_unit_tensor(is_greater.len());
|
|
|
|
|
|
|
|
|
|
enforce_equality(config, region, &[unit, is_greater])?;
|
|
|
|
|
|
|
|
|
|
// assert that this is a permutation/shuffle
|
|
|
|
|
shuffles(config, region, &[assigned_sort.clone()], &[input.clone()])?;
|
|
|
|
|
@@ -708,7 +667,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Select top k elements
|
|
|
|
|
pub fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -798,9 +757,7 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
|
|
|
|
|
let sum = sum(config, region, &[assigned_output.clone()])?;
|
|
|
|
|
// assert sum is 1
|
|
|
|
|
let mut unit = Tensor::from(vec![F::from(1)].into_iter());
|
|
|
|
|
unit.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let unit: ValTensor<F> = unit.try_into()?;
|
|
|
|
|
let unit = create_unit_tensor(1);
|
|
|
|
|
|
|
|
|
|
enforce_equality(config, region, &[unit.clone(), sum])?;
|
|
|
|
|
|
|
|
|
|
@@ -817,7 +774,7 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Dynamic lookup
|
|
|
|
|
pub fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
lookups: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -843,23 +800,18 @@ pub fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
let table_len = table_0.len();
|
|
|
|
|
|
|
|
|
|
// now create a vartensor of constants for the dynamic lookup index
|
|
|
|
|
let mut table_index = Tensor::from(
|
|
|
|
|
vec![ValType::Constant(F::from(dynamic_lookup_index as u64)); table_len].into_iter(),
|
|
|
|
|
);
|
|
|
|
|
table_index.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
|
|
|
|
|
let _table_index =
|
|
|
|
|
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index.into())?;
|
|
|
|
|
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
|
|
|
|
|
|
|
|
|
|
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
|
|
|
|
|
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
|
|
|
|
|
let lookup_len = lookup_0.len();
|
|
|
|
|
|
|
|
|
|
// now set the lookup index
|
|
|
|
|
let mut lookup_index = Tensor::from(
|
|
|
|
|
vec![ValType::Constant(F::from(dynamic_lookup_index as u64)); lookup_len].into_iter(),
|
|
|
|
|
);
|
|
|
|
|
lookup_index.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index.into())?;
|
|
|
|
|
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
|
|
|
|
|
|
|
|
|
|
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
|
|
|
|
|
|
|
|
|
|
if !region.is_dummy() {
|
|
|
|
|
(0..table_len)
|
|
|
|
|
@@ -900,7 +852,7 @@ pub fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Shuffle arg
|
|
|
|
|
pub fn shuffles<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
input: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -918,11 +870,8 @@ pub fn shuffles<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
let reference_len = reference.len();
|
|
|
|
|
|
|
|
|
|
// now create a vartensor of constants for the shuffle index
|
|
|
|
|
let mut index = Tensor::from(
|
|
|
|
|
vec![ValType::Constant(F::from(shuffle_index as u64)); reference_len].into_iter(),
|
|
|
|
|
);
|
|
|
|
|
index.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let index = region.assign_shuffle(&config.shuffles.references[1], &index.into())?;
|
|
|
|
|
let index = create_constant_tensor(F::from(shuffle_index as u64), reference_len);
|
|
|
|
|
let index = region.assign_shuffle(&config.shuffles.references[1], &index)?;
|
|
|
|
|
|
|
|
|
|
let input = region.assign(&config.shuffles.inputs[0], &input)?;
|
|
|
|
|
region.assign(&config.shuffles.inputs[1], &index)?;
|
|
|
|
|
@@ -966,7 +915,7 @@ pub fn shuffles<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// One hot accumulated layout
|
|
|
|
|
pub fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1019,7 +968,7 @@ pub fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Gather accumulated layout
|
|
|
|
|
pub fn gather<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1113,7 +1062,7 @@ pub fn gather<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Gather accumulated layout
|
|
|
|
|
pub fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1133,15 +1082,48 @@ pub fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
region.increment(std::cmp::max(input.len(), index.len()));
|
|
|
|
|
|
|
|
|
|
// Calculate the output tensor size
|
|
|
|
|
let input_dim = input.dims()[dim];
|
|
|
|
|
let input_dims = input.dims();
|
|
|
|
|
let output_size = index.dims().to_vec();
|
|
|
|
|
|
|
|
|
|
// these will be assigned as constants
|
|
|
|
|
let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x)));
|
|
|
|
|
let mut indices = Tensor::from((0..input_dims[dim] as u64).map(|x| F::from(x)));
|
|
|
|
|
indices.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
|
|
|
|
|
region.increment(indices.len());
|
|
|
|
|
|
|
|
|
|
let mut iteration_dims = output_size.clone();
|
|
|
|
|
iteration_dims[dim] = 1;
|
|
|
|
|
|
|
|
|
|
// Allocate memory for the output tensor
|
|
|
|
|
let cartesian_coord = iteration_dims
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|x| 0..*x)
|
|
|
|
|
.multi_cartesian_product()
|
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
let mut results = HashMap::new();
|
|
|
|
|
|
|
|
|
|
for coord in cartesian_coord {
|
|
|
|
|
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
|
|
|
|
slice[dim] = 0..input_dims[dim];
|
|
|
|
|
|
|
|
|
|
let mut sliced_input = input.get_slice(&slice)?;
|
|
|
|
|
sliced_input.flatten();
|
|
|
|
|
|
|
|
|
|
slice[dim] = 0..output_size[dim];
|
|
|
|
|
let mut sliced_index = index.get_slice(&slice)?;
|
|
|
|
|
sliced_index.flatten();
|
|
|
|
|
|
|
|
|
|
let res = select(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[sliced_input, sliced_index],
|
|
|
|
|
indices.clone(),
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
results.insert(coord, res);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Allocate memory for the output tensor
|
|
|
|
|
let cartesian_coord = output_size
|
|
|
|
|
.iter()
|
|
|
|
|
@@ -1149,39 +1131,20 @@ pub fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
.multi_cartesian_product()
|
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
let mut output = Tensor::new(None, &output_size)?;
|
|
|
|
|
|
|
|
|
|
let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
|
|
|
|
|
let output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
|
|
|
|
|
let coord = cartesian_coord[i].clone();
|
|
|
|
|
let index_val = index.get_inner_tensor()?.get(&coord);
|
|
|
|
|
|
|
|
|
|
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
|
|
|
|
slice[dim] = 0..input_dim;
|
|
|
|
|
|
|
|
|
|
let mut sliced_input = input.get_slice(&slice)?;
|
|
|
|
|
sliced_input.flatten();
|
|
|
|
|
|
|
|
|
|
let index_valtensor: ValTensor<F> = Tensor::from([index_val.clone()].into_iter()).into();
|
|
|
|
|
|
|
|
|
|
let res = select(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[sliced_input, index_valtensor],
|
|
|
|
|
indices.clone(),
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
let res = res.get_inner_tensor()?;
|
|
|
|
|
|
|
|
|
|
Ok(res[0].clone())
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
region.apply_in_loop(&mut output, inner_loop_function)?;
|
|
|
|
|
let mut key = coord.clone();
|
|
|
|
|
key[dim] = 0;
|
|
|
|
|
let result = &results.get(&key).ok_or("missing result")?;
|
|
|
|
|
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
|
|
|
|
|
Ok::<ValType<F>, region::RegionError>(o)
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
Ok(output.into())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Gather accumulated layout
|
|
|
|
|
pub fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 3],
|
|
|
|
|
@@ -1228,12 +1191,6 @@ pub fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
.multi_cartesian_product()
|
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
let mut unit = Tensor::from(vec![F::from(1)].into_iter());
|
|
|
|
|
unit.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
let unit: ValTensor<F> = unit.try_into()?;
|
|
|
|
|
region.assign(&config.custom_gates.inputs[1], &unit)?;
|
|
|
|
|
region.increment(1);
|
|
|
|
|
|
|
|
|
|
let mut output: Tensor<()> = Tensor::new(None, &output_size)?;
|
|
|
|
|
|
|
|
|
|
let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
|
|
|
|
|
@@ -1253,25 +1210,9 @@ pub fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
|
|
|
|
|
let mask = equals(config, region, &[index_valtensor, indices.clone()])?;
|
|
|
|
|
|
|
|
|
|
let one_minus_mask = pairwise(config, region, &[unit.clone(), mask.clone()], BaseOp::Sub)?;
|
|
|
|
|
|
|
|
|
|
let pairwise_prod = pairwise(config, region, &[src_valtensor, mask], BaseOp::Mult)?;
|
|
|
|
|
let pairwise_prod_2 = pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[sliced_input, one_minus_mask],
|
|
|
|
|
BaseOp::Mult,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
let res = pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[pairwise_prod, pairwise_prod_2],
|
|
|
|
|
BaseOp::Add,
|
|
|
|
|
)?;
|
|
|
|
|
let res = iff(config, region, &[mask, src_valtensor, sliced_input])?;
|
|
|
|
|
|
|
|
|
|
let input_cartesian_coord = slice.into_iter().multi_cartesian_product();
|
|
|
|
|
|
|
|
|
|
let mutable_input_inner = input.get_inner_tensor_mut()?;
|
|
|
|
|
|
|
|
|
|
for (i, r) in res.get_inner_tensor()?.iter().enumerate() {
|
|
|
|
|
@@ -1294,7 +1235,7 @@ pub fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// sum accumulated layout
|
|
|
|
|
pub fn sum<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1318,7 +1259,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
|
|
|
|
|
// if empty return a const
|
|
|
|
|
if values[0].is_empty() {
|
|
|
|
|
return Ok(Tensor::from([ValType::Constant(F::ZERO)].into_iter()).into());
|
|
|
|
|
return Ok(create_zero_tensor(1));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let block_width = config.custom_gates.output.num_inner_cols();
|
|
|
|
|
@@ -1377,7 +1318,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// product accumulated layout
|
|
|
|
|
pub fn prod<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1393,7 +1334,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
trace!("finding const zero indices took: {:?}", elapsed);
|
|
|
|
|
// if empty return a const
|
|
|
|
|
if !removal_indices.is_empty() {
|
|
|
|
|
return Ok(Tensor::from([ValType::Constant(F::ZERO)].into_iter()).into());
|
|
|
|
|
return Ok(create_zero_tensor(1));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let block_width = config.custom_gates.output.num_inner_cols();
|
|
|
|
|
@@ -1518,7 +1459,7 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Sum accumulated layout
|
|
|
|
|
pub fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1529,7 +1470,7 @@ pub fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Sum accumulated layout
|
|
|
|
|
pub fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1540,7 +1481,7 @@ pub fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// argmax layout
|
|
|
|
|
pub fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1564,7 +1505,7 @@ pub fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Max accumulated layout
|
|
|
|
|
pub fn max_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1576,7 +1517,7 @@ pub fn max_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Argmin layout
|
|
|
|
|
pub fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1600,7 +1541,7 @@ pub fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Min accumulated layout
|
|
|
|
|
pub fn min_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1612,7 +1553,7 @@ pub fn min_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Pairwise (elementwise) op layout
|
|
|
|
|
pub fn pairwise<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1767,7 +1708,7 @@ pub fn pairwise<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// expand the tensor to the given shape
|
|
|
|
|
pub fn expand<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1780,7 +1721,7 @@ pub fn expand<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|
pub fn greater<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1803,7 +1744,7 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|
pub fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1826,7 +1767,7 @@ pub fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|
pub fn less<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn less<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1836,7 +1777,7 @@ pub fn less<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|
pub fn less_equal<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1846,7 +1787,7 @@ pub fn less_equal<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// And boolean operation
|
|
|
|
|
pub fn and<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn and<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1857,7 +1798,7 @@ pub fn and<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Or boolean operation
|
|
|
|
|
pub fn or<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn or<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1873,7 +1814,7 @@ pub fn or<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Equality boolean operation
|
|
|
|
|
pub fn equals<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1883,7 +1824,7 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Equality boolean operation
|
|
|
|
|
pub fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -1898,14 +1839,12 @@ pub fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
// constant of 1
|
|
|
|
|
let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
|
|
|
|
|
ones.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
|
|
|
|
|
let ones = create_unit_tensor(1);
|
|
|
|
|
// subtract
|
|
|
|
|
let output = pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[ones.into(), product_values_and_invert],
|
|
|
|
|
&[ones, product_values_and_invert],
|
|
|
|
|
BaseOp::Sub,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
@@ -1918,7 +1857,7 @@ pub fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Xor boolean operation
|
|
|
|
|
pub fn xor<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
@@ -1944,21 +1883,15 @@ pub fn xor<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Not boolean operation
|
|
|
|
|
pub fn not<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn not<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
let mask = values[0].clone();
|
|
|
|
|
|
|
|
|
|
let unit: ValTensor<F> = Tensor::from(
|
|
|
|
|
vec![region.assign_constant(&config.custom_gates.inputs[0], F::from(1))?].into_iter(),
|
|
|
|
|
)
|
|
|
|
|
.into();
|
|
|
|
|
|
|
|
|
|
// to leverage sparsity we don't assign this guy
|
|
|
|
|
let nil: ValTensor<F> = Tensor::from(vec![ValType::Constant(F::from(0))].into_iter()).into();
|
|
|
|
|
region.next();
|
|
|
|
|
let unit = create_unit_tensor(1);
|
|
|
|
|
let nil = create_zero_tensor(1);
|
|
|
|
|
|
|
|
|
|
let res = iff(config, region, &[mask, nil, unit])?;
|
|
|
|
|
|
|
|
|
|
@@ -1966,7 +1899,7 @@ pub fn not<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Iff
|
|
|
|
|
pub fn iff<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 3],
|
|
|
|
|
@@ -1974,11 +1907,7 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
// if mask > 0 then output a else output b
|
|
|
|
|
let (mask, a, b) = (&values[0], &values[1], &values[2]);
|
|
|
|
|
|
|
|
|
|
let unit: ValTensor<F> = Tensor::from(
|
|
|
|
|
vec![region.assign_constant(&config.custom_gates.inputs[0], F::from(1))?].into_iter(),
|
|
|
|
|
)
|
|
|
|
|
.into();
|
|
|
|
|
|
|
|
|
|
let unit = create_unit_tensor(1);
|
|
|
|
|
// make sure mask is boolean
|
|
|
|
|
let assigned_mask = boolean_identity(config, region, &[mask.clone()], true)?;
|
|
|
|
|
|
|
|
|
|
@@ -1994,23 +1923,17 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Negation operation accumulated layout
|
|
|
|
|
pub fn neg<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
let mut nil = Tensor::from(vec![ValType::Constant(F::from(0))].into_iter());
|
|
|
|
|
nil.set_visibility(&crate::graph::Visibility::Fixed);
|
|
|
|
|
pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[nil.into(), values[0].clone()],
|
|
|
|
|
BaseOp::Sub,
|
|
|
|
|
)
|
|
|
|
|
let nil = create_zero_tensor(1);
|
|
|
|
|
pairwise(config, region, &[nil, values[0].clone()], BaseOp::Sub)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Sumpool accumulated layout
|
|
|
|
|
pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>],
|
|
|
|
|
@@ -2022,11 +1945,10 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
let batch_size = values[0].dims()[0];
|
|
|
|
|
let image_channels = values[0].dims()[1];
|
|
|
|
|
|
|
|
|
|
let unit = region.assign_constant(&config.custom_gates.inputs[1], F::from(1))?;
|
|
|
|
|
region.next();
|
|
|
|
|
|
|
|
|
|
let mut kernel = Tensor::from(0..kernel_shape.0 * kernel_shape.1).map(|_| unit.clone());
|
|
|
|
|
let mut kernel = create_unit_tensor(kernel_shape.0 * kernel_shape.1);
|
|
|
|
|
kernel.reshape(&[1, 1, kernel_shape.0, kernel_shape.1])?;
|
|
|
|
|
let kernel = region.assign(&config.custom_gates.inputs[1], &kernel)?;
|
|
|
|
|
region.increment(kernel.len());
|
|
|
|
|
|
|
|
|
|
let cartesian_coord = [(0..batch_size), (0..image_channels)]
|
|
|
|
|
.iter()
|
|
|
|
|
@@ -2044,7 +1966,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
let output = conv(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[input, kernel.clone().into()],
|
|
|
|
|
&[input, kernel.clone()],
|
|
|
|
|
padding,
|
|
|
|
|
stride,
|
|
|
|
|
)?;
|
|
|
|
|
@@ -2071,7 +1993,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Convolution accumulated layout
|
|
|
|
|
pub fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2135,7 +2057,9 @@ pub fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// DeConvolution accumulated layout
|
|
|
|
|
pub fn deconv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync>(
|
|
|
|
|
pub(crate) fn deconv<
|
|
|
|
|
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
|
|
|
|
|
>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
inputs: &[ValTensor<F>],
|
|
|
|
|
@@ -2159,8 +2083,6 @@ pub fn deconv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std:
|
|
|
|
|
let (kernel_height, kernel_width) = (kernel.dims()[2], kernel.dims()[3]);
|
|
|
|
|
|
|
|
|
|
let null_val = ValType::Constant(F::ZERO);
|
|
|
|
|
// region.assign_constant(&config.custom_gates.inputs[1], F::from(0))?;
|
|
|
|
|
// region.next();
|
|
|
|
|
|
|
|
|
|
let mut expanded_image = image.clone();
|
|
|
|
|
expanded_image.intercalate_values(null_val.clone(), stride.0, 2)?;
|
|
|
|
|
@@ -2228,7 +2150,9 @@ pub fn deconv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Convolution accumulated layout
|
|
|
|
|
pub fn conv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync>(
|
|
|
|
|
pub(crate) fn conv<
|
|
|
|
|
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
|
|
|
|
|
>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>],
|
|
|
|
|
@@ -2408,7 +2332,7 @@ pub fn conv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::m
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Power accumulated layout
|
|
|
|
|
pub fn pow<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2424,7 +2348,7 @@ pub fn pow<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Rescaled op accumulated layout
|
|
|
|
|
pub fn rescale<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>],
|
|
|
|
|
@@ -2437,8 +2361,7 @@ pub fn rescale<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let multiplier: ValTensor<F> =
|
|
|
|
|
Tensor::from(vec![ValType::Constant(F::from(scales[i].1 as u64))].into_iter()).into();
|
|
|
|
|
let multiplier = create_constant_tensor(F::from(scales[i].1 as u64), 1);
|
|
|
|
|
let scaled_input = pairwise(config, region, &[ri.clone(), multiplier], BaseOp::Mult)?;
|
|
|
|
|
rescaled_inputs.push(scaled_input);
|
|
|
|
|
}
|
|
|
|
|
@@ -2446,44 +2369,8 @@ pub fn rescale<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
Ok(rescaled_inputs)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Pack accumulated layout
|
|
|
|
|
pub fn pack<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
base: u32,
|
|
|
|
|
scale: u32,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
let mut t = values[0].clone();
|
|
|
|
|
t.flatten();
|
|
|
|
|
|
|
|
|
|
// these unwraps should never ever fail if the Tensortypes are correctly implemented
|
|
|
|
|
// if anything we want these to hard fail if not implemented
|
|
|
|
|
let mut base_t = <F as TensorType>::zero().ok_or(TensorError::FeltError)?;
|
|
|
|
|
for _ in 0..base {
|
|
|
|
|
base_t += <F as TensorType>::one().ok_or(TensorError::FeltError)?;
|
|
|
|
|
}
|
|
|
|
|
let mut accum_base = vec![];
|
|
|
|
|
let base_tensor = Tensor::new(Some(&[base_t]), &[1])?;
|
|
|
|
|
for i in 0..t.dims().iter().product::<usize>() {
|
|
|
|
|
accum_base.push(Value::known(base_tensor.pow((i as u32) * (scale + 1))?[0]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let base_tensor = Tensor::new(Some(&accum_base), &[accum_base.len()])?;
|
|
|
|
|
let base_prod = pairwise(
|
|
|
|
|
config,
|
|
|
|
|
region,
|
|
|
|
|
&[t.clone(), base_tensor.into()],
|
|
|
|
|
BaseOp::Mult,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
let res = sum(config, region, &[base_prod])?;
|
|
|
|
|
|
|
|
|
|
Ok(res)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Dummy (no contraints) reshape layout
|
|
|
|
|
pub fn reshape<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
new_dims: &[usize],
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
@@ -2493,7 +2380,7 @@ pub fn reshape<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Dummy (no contraints) move_axis layout
|
|
|
|
|
pub fn move_axis<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
source: usize,
|
|
|
|
|
destination: usize,
|
|
|
|
|
@@ -2504,7 +2391,7 @@ pub fn move_axis<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// resize layout
|
|
|
|
|
pub fn resize<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2518,7 +2405,7 @@ pub fn resize<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Slice layout
|
|
|
|
|
pub fn slice<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2535,7 +2422,7 @@ pub fn slice<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Concat layout
|
|
|
|
|
pub fn concat<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
values: &[ValTensor<F>],
|
|
|
|
|
axis: &usize,
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
@@ -2547,7 +2434,7 @@ pub fn concat<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
|
|
|
|
pub fn identity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2562,7 +2449,7 @@ pub fn identity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// is zero identity constraint.
|
|
|
|
|
pub fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2594,7 +2481,7 @@ pub fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
|
|
|
|
pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2630,7 +2517,7 @@ pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Downsample layout
|
|
|
|
|
pub fn downsample<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2647,12 +2534,17 @@ pub fn downsample<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// layout for enforcing two sets of cells to be equal
|
|
|
|
|
pub fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
|
|
|
|
// assert of same len
|
|
|
|
|
if values[0].len() != values[1].len() {
|
|
|
|
|
return Err(Box::new(TensorError::DimMismatch(
|
|
|
|
|
"enforce_equality".to_string(),
|
|
|
|
|
)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// assigns the instance to the advice.
|
|
|
|
|
let input = region.assign(&config.custom_gates.inputs[1], &values[0])?;
|
|
|
|
|
@@ -2668,7 +2560,7 @@ pub fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// layout for range check.
|
|
|
|
|
pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2749,7 +2641,7 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// layout for nonlinearity check.
|
|
|
|
|
pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2851,7 +2743,7 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Argmax
|
|
|
|
|
pub fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2889,7 +2781,7 @@ pub fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Argmin
|
|
|
|
|
pub fn argmin<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2927,7 +2819,7 @@ pub fn argmin<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// max layout
|
|
|
|
|
pub fn max<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -2937,7 +2829,7 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// min layout
|
|
|
|
|
pub fn min<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn min<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -3048,7 +2940,7 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// softmax layout
|
|
|
|
|
pub fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -3068,7 +2960,7 @@ pub fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// softmax func
|
|
|
|
|
pub fn softmax<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 1],
|
|
|
|
|
@@ -3093,7 +2985,7 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
/// Checks that the percent error between the expected public output and the actual output value
|
|
|
|
|
/// is within the percent error expressed by the `tol` input, where `tol == 1.0` means the percent
|
|
|
|
|
/// error tolerance is 1 percent.
|
|
|
|
|
pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
pub(crate) fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
|
|
|
|
config: &BaseConfig<F>,
|
|
|
|
|
region: &mut RegionCtx<F>,
|
|
|
|
|
values: &[ValTensor<F>; 2],
|
|
|
|
|
|