Mirror of https://github.com/zkonduit/ezkl.git, synced 2026-01-13 16:27:59 -05:00
Compare commits
3 Commits
| Author | SHA1 | Date |
|---|---|---|
| | ea535e2ecd | |
| | f8aa91ed08 | |
| | a59e3780b2 | |
@@ -17,7 +17,6 @@ pub enum BaseOp {
    Sub,
    SumInit,
    Sum,
    IsZero,
    IsBoolean,
}

@@ -35,7 +34,6 @@ impl BaseOp {
            BaseOp::Add => a + b,
            BaseOp::Sub => a - b,
            BaseOp::Mult => a * b,
            BaseOp::IsZero => b,
            BaseOp::IsBoolean => b,
            _ => panic!("nonaccum_f called on accumulating operation"),
        }

@@ -76,7 +74,6 @@ impl BaseOp {
            BaseOp::Mult => "MULT",
            BaseOp::Sum => "SUM",
            BaseOp::SumInit => "SUMINIT",
            BaseOp::IsZero => "ISZERO",
            BaseOp::IsBoolean => "ISBOOLEAN",
        }
    }

@@ -93,7 +90,6 @@ impl BaseOp {
            BaseOp::Mult => (0, 1),
            BaseOp::Sum => (-1, 2),
            BaseOp::SumInit => (0, 1),
            BaseOp::IsZero => (0, 1),
            BaseOp::IsBoolean => (0, 1),
        }
    }

@@ -110,7 +106,6 @@ impl BaseOp {
            BaseOp::Mult => 2,
            BaseOp::Sum => 1,
            BaseOp::SumInit => 1,
            BaseOp::IsZero => 0,
            BaseOp::IsBoolean => 0,
        }
    }

@@ -127,7 +122,6 @@ impl BaseOp {
            BaseOp::SumInit => 0,
            BaseOp::CumProd => 1,
            BaseOp::CumProdInit => 0,
            BaseOp::IsZero => 0,
            BaseOp::IsBoolean => 0,
        }
    }
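For orientation, here is a minimal standalone model of the non-accumulating dispatch shown above; `i64` stands in for the prime field `F` so the sketch runs on its own, and everything beyond the arm structure mirrored from the diff is illustrative. `IsZero` and `IsBoolean` pass `b` through because the zero/boolean condition is enforced by the gate's constraint polynomial, not by this witness-side helper.

```rust
#[derive(Debug, Clone, Copy)]
enum BaseOp {
    Add,
    Sub,
    Mult,
    IsZero,
    IsBoolean,
}

impl BaseOp {
    // Mirrors the `nonaccum_f` arm structure in the diff.
    fn nonaccum_f(&self, (a, b): (i64, i64)) -> i64 {
        match self {
            BaseOp::Add => a + b,
            BaseOp::Sub => a - b,
            BaseOp::Mult => a * b,
            // the constraint system, not this function, checks b
            BaseOp::IsZero | BaseOp::IsBoolean => b,
        }
    }
}

fn main() {
    assert_eq!(BaseOp::Sub.nonaccum_f((5, 3)), 2);
    assert_eq!(BaseOp::IsBoolean.nonaccum_f((5, 1)), 1);
}
```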
@@ -387,7 +387,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
                nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
            }
        }

@@ -432,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {

                    vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
                }
                BaseOp::IsZero => {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
                        .expect("non accum: output query failed");
                    vec![expected_output[base_op.constraint_idx()].clone()]
                }
                _ => {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
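The two gates visible in this hunk follow the standard pattern: booleanity constrains `o * (o - 1) = 0` on the queried output cell, while the is-zero gate pins the queried cell itself to zero. A quick standalone check of the booleanity polynomial, with plain integers standing in for field elements:

```rust
// o * (o - 1) vanishes exactly on the boolean values 0 and 1.
fn booleanity(o: i64) -> i64 {
    o * (o - 1)
}

fn main() {
    assert_eq!(booleanity(0), 0);
    assert_eq!(booleanity(1), 0);
    assert_ne!(booleanity(2), 0); // any non-boolean witness fails the gate
}
```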
@@ -132,41 +132,6 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
    Ok(claimed_output)
}

-fn recip_int<F: PrimeField + TensorType + PartialOrd>(
-    config: &BaseConfig<F>,
-    region: &mut RegionCtx<F>,
-    input: &[ValTensor<F>; 1],
-) -> Result<ValTensor<F>, Box<dyn Error>> {
-    // assert is boolean
-    let zero_inverse_val = tensor::ops::nonlinearities::zero_recip(1.0)[0];
-    // get values where input is 0
-    let zero_mask = equals_zero(config, region, input)?;
-
-    let zero_mask_minus_one = pairwise(
-        config,
-        region,
-        &[zero_mask.clone(), create_unit_tensor(1)],
-        BaseOp::Sub,
-    )?;
-
-    let zero_inverse_val = pairwise(
-        config,
-        region,
-        &[
-            zero_mask,
-            create_constant_tensor(i128_to_felt(zero_inverse_val), 1),
-        ],
-        BaseOp::Mult,
-    )?;
-
-    pairwise(
-        config,
-        region,
-        &[zero_mask_minus_one, zero_inverse_val],
-        BaseOp::Add,
-    )
-}
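The zero-handling in the removed helper is pure mask algebra: with `m = 1` where the input is zero (0 elsewhere) and `z` the designated inverse-of-zero constant, it returns `(m - 1) + m * z`. A tiny integer check of that expression (`z` here is an arbitrary stand-in for `zero_recip(1.0)[0]`):

```rust
fn recip_int_mask(m: i64, z: i64) -> i64 {
    (m - 1) + m * z
}

fn main() {
    let z = 4096; // hypothetical stand-in for the inverse-of-zero constant
    assert_eq!(recip_int_mask(1, z), z); // input == 0: mapped to z
    assert_eq!(recip_int_mask(0, z), -1); // input != 0: the mask algebra yields -1
}
```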
/// recip accumulated layout
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
@@ -175,10 +140,6 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
    input_scale: F,
    output_scale: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
-    if output_scale == F::ONE || output_scale == F::ZERO {
-        return recip_int(config, region, value);
-    }

    let input = value[0].clone();
    let input_dims = input.dims();

@@ -188,8 +149,11 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
    // range_check_bracket is min of input_scale * output_scale and 2^F::S - 3
    let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4));

-    let input_scale_ratio =
-        i128_to_felt(integer_input_scale * integer_output_scale / range_check_len);
+    let input_scale_ratio = if range_check_len > 0 {
+        i128_to_felt(integer_input_scale * integer_output_scale / range_check_len)
+    } else {
+        F::ONE
+    };

    let range_check_bracket = range_check_len / 2;
@@ -234,11 +198,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(

    let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;

-    let equal_inverse_mask = equals(
-        config,
-        region,
-        &[claimed_output.clone(), zero_inverse],
-    )?;
+    let equal_inverse_mask = equals(config, region, &[claimed_output.clone(), zero_inverse])?;

    // assert the two masks are equal
    enforce_equality(

@@ -249,12 +209,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(

    let unit_scale = create_constant_tensor(i128_to_felt(range_check_len), 1);

-    let unit_mask = pairwise(
-        config,
-        region,
-        &[equal_zero_mask, unit_scale],
-        BaseOp::Mult,
-    )?;
+    let unit_mask = pairwise(config, region, &[equal_zero_mask, unit_scale], BaseOp::Mult)?;

    // now add the unit mask to the rebased_div
    let rebased_offset_div = pairwise(config, region, &[rebased_div, unit_mask], BaseOp::Add)?;
@@ -691,14 +646,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 2],
-    dim_indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    let (mut input, index) = (values[0].clone(), values[1].clone());
    input.flatten();

-    if !(dim_indices.all_prev_assigned() || region.is_dummy()) {
-        return Err("dim_indices must be assigned".into());
-    }
+    // these will be assigned as constants
+    let dim_indices: ValTensor<F> =
+        Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();

    let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;
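`select` is the primitive the gather family reduces to: flatten the input, then constrain `output[i] = input[index[i]]`. An unconstrained model of that contract, just the indexing semantics:

```rust
// Unconstrained model of `select` semantics: output[i] = input[index[i]].
fn select_model<T: Clone>(input: &[T], index: &[usize]) -> Vec<T> {
    index.iter().map(|&i| input[i].clone()).collect()
}

fn main() {
    let input = [10, 20, 30, 40];
    assert_eq!(select_model(&input, &[3, 0, 0]), vec![40, 10, 10]);
}
```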
@@ -974,91 +928,25 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
    values: &[ValTensor<F>; 2],
    dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
-    let (mut input, mut index_clone) = (values[0].clone(), values[1].clone());
+    let (input, mut index_clone) = (values[0].clone(), values[1].clone());
    index_clone.flatten();
    if index_clone.is_singleton() {
        index_clone.reshape(&[1])?;
    }

-    let mut assigned_len = vec![];
-    if !input.all_prev_assigned() {
-        input = region.assign(&config.custom_gates.inputs[0], &input)?;
-        assigned_len.push(input.len());
-    }
-    if !index_clone.all_prev_assigned() {
-        index_clone = region.assign(&config.custom_gates.inputs[1], &index_clone)?;
-        assigned_len.push(index_clone.len());
-    }
-
-    if !assigned_len.is_empty() {
-        // safe to unwrap since we've just checked it has at least one element
-        region.increment(*assigned_len.iter().max().unwrap());
-    }

    // Calculate the output tensor size
    let input_dims = input.dims();
    let mut output_size = input_dims.to_vec();
    output_size[dim] = index_clone.dims()[0];

-    // these will be assigned as constants
-    let mut indices = Tensor::from((0..input.dims()[dim] as u64).map(|x| F::from(x)));
-    indices.set_visibility(&crate::graph::Visibility::Fixed);
-    let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
-    region.increment(indices.len());
+    let linear_index =
+        linearize_element_index(config, region, &[index_clone], input_dims, dim, true)?;

-    let mut iteration_dims = output_size.clone();
-    iteration_dims[dim] = 1;
+    let mut output = select(config, region, &[input, linear_index])?;

-    // Allocate memory for the output tensor
-    let cartesian_coord = iteration_dims
-        .iter()
-        .map(|x| 0..*x)
-        .multi_cartesian_product()
-        .collect::<Vec<_>>();
-
-    let mut results = HashMap::new();
-
-    for coord in cartesian_coord {
-        let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
-        slice[dim] = 0..input_dims[dim];
-
-        let mut sliced_input = input.get_slice(&slice)?;
-        sliced_input.flatten();
-
-        let res = select(
-            config,
-            region,
-            &[sliced_input, index_clone.clone()],
-            indices.clone(),
-        )?;
-
-        results.insert(coord, res);
-    }
-
-    // Allocate memory for the output tensor
-    let cartesian_coord = output_size
-        .iter()
-        .map(|x| 0..*x)
-        .multi_cartesian_product()
-        .collect::<Vec<_>>();
-
-    let mut output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
-        let coord = cartesian_coord[i].clone();
-        let mut key = coord.clone();
-        key[dim] = 0;
-        let result = &results.get(&key).ok_or("missing result")?;
-        let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
-        Ok::<ValType<F>, region::RegionError>(o)
-    })?;

    // Reshape the output tensor
    if index_clone.is_singleton() {
        output_size.remove(dim);
    }
    output.reshape(&output_size)?;

-    Ok(output.into())
+    Ok(output)
}
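The rewrite above collapses the per-slice select loop into a single `select` over the flattened input, with `linearize_element_index` producing row-major offsets. A plain model of gather-along-a-dimension via flat indices (2D case, semantics only, no constraints):

```rust
// Model of the new gather path: build flat (row-major) indices for every
// output coordinate, then do one flat lookup.
fn gather_dim1_model<T: Clone>(input: &[T], dims: (usize, usize), index: &[usize]) -> Vec<T> {
    let (rows, cols) = dims;
    let mut out = Vec::with_capacity(rows * index.len());
    for r in 0..rows {
        for &c in index {
            out.push(input[r * cols + c].clone()); // linearized index r*cols + c
        }
    }
    out
}

fn main() {
    // 2x3 input [[1,2,3],[4,5,6]], gather columns [2,0] -> [[3,1],[6,4]]
    let input = [1, 2, 3, 4, 5, 6];
    assert_eq!(gather_dim1_model(&input, (2, 3), &[2, 0]), vec![3, 1, 6, 4]);
}
```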
/// Gather accumulated layout
@@ -1067,82 +955,173 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 2],
    dim: usize,
-) -> Result<ValTensor<F>, Box<dyn Error>> {
-    let (mut input, mut index) = (values[0].clone(), values[1].clone());
+) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
+    let (input, index) = (values[0].clone(), values[1].clone());

    assert_eq!(input.dims().len(), index.dims().len());

-    if !input.all_prev_assigned() {
-        input = region.assign(&config.custom_gates.inputs[0], &input)?;
-    }
-    if !index.all_prev_assigned() {
-        index = region.assign(&config.custom_gates.inputs[1], &index)?;
-    }
-
-    region.increment(std::cmp::max(input.len(), index.len()));

    // Calculate the output tensor size
-    let input_dims = input.dims();
    let output_size = index.dims().to_vec();

-    // these will be assigned as constants
-    let mut indices = Tensor::from((0..input_dims[dim] as u64).map(|x| F::from(x)));
-    indices.set_visibility(&crate::graph::Visibility::Fixed);
-    let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
-    region.increment(indices.len());
+    let linear_index = linearize_element_index(config, region, &[index], input.dims(), dim, false)?;

-    let mut iteration_dims = output_size.clone();
-    iteration_dims[dim] = 1;
+    let mut output = select(config, region, &[input, linear_index.clone()])?;

    output.reshape(&output_size)?;

+    Ok((output, linear_index))
}
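`gather_elements` follows the ONNX GatherElements contract: the output has the index tensor's shape and `out[coord] = input[coord with coord[dim] := index[coord]]`. A plain 2D model over `dim = 1`:

```rust
// Model of GatherElements over dim 1 of a 2D tensor:
// out[r][c] = input[r][ index[r][c] ].
fn gather_elements_dim1_model<T: Clone>(input: &[Vec<T>], index: &[Vec<usize>]) -> Vec<Vec<T>> {
    index
        .iter()
        .enumerate()
        .map(|(r, row)| row.iter().map(|&c| input[r][c].clone()).collect())
        .collect()
}

fn main() {
    let input = vec![vec![1, 2], vec![3, 4]];
    let index = vec![vec![0, 0], vec![1, 0]];
    assert_eq!(gather_elements_dim1_model(&input, &index), vec![vec![1, 1], vec![4, 3]]);
}
```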
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// For instance if the dims is [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 1],
    dims: &[usize],
    dim: usize,
    is_flat_index: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    let index = values[0].clone();
    if !is_flat_index {
        assert_eq!(index.dims().len(), dims.len());
        // if the index is already flat, return it
        if index.dims().len() == 1 {
            return Ok(index);
        }
    }

    let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;

    let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
        let mut res = 1;
        for dim in dims.iter().skip(i + 1) {
            res *= dim;
        }
        Ok::<_, region::RegionError>(F::from(res as u64))
    })?;

    let iteration_dims = if is_flat_index {
        let mut dims = dims.to_vec();
        dims[dim] = index.len();
        dims
    } else {
        index.dims().to_vec()
    };

    // Allocate memory for the output tensor
    let cartesian_coord = iteration_dims
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

-    let mut results = HashMap::new();
    let val_dim_multiplier: ValTensor<F> = dim_multiplier
        .get_slice(&[dim..dim + 1])?
        .map(|x| ValType::Constant(x))
        .into();

-    for coord in cartesian_coord {
-        let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
-        slice[dim] = 0..input_dims[dim];
    let mut output = Tensor::new(None, &[cartesian_coord.len()])?;

-        let mut sliced_input = input.get_slice(&slice)?;
-        sliced_input.flatten();
    let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
        let coord = cartesian_coord[i].clone();
        let slice: Vec<Range<usize>> = if is_flat_index {
            coord[dim..dim + 1].iter().map(|x| *x..*x + 1).collect()
        } else {
            coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>()
        };

-        slice[dim] = 0..output_size[dim];
-        let mut sliced_index = index.get_slice(&slice)?;
-        sliced_index.flatten();
        let index_val = index.get_slice(&slice)?;

-        let res = select(
        let mut const_offset = F::ZERO;
        for i in 0..dims.len() {
            if i != dim {
                const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
            }
        }
        let const_offset = create_constant_tensor(const_offset, 1);

        let res = pairwise(
            config,
            region,
-            &[sliced_input, sliced_index],
-            indices.clone(),
            &[index_val, val_dim_multiplier.clone()],
            BaseOp::Mult,
        )?;

-        results.insert(coord, res);
-    }
        let res = pairwise(config, region, &[res, const_offset], BaseOp::Add)?;

-    // Allocate memory for the output tensor
-    let cartesian_coord = output_size
-        .iter()
-        .map(|x| 0..*x)
-        .multi_cartesian_product()
-        .collect::<Vec<_>>();
        Ok(res.get_inner_tensor()?[0].clone())
    };

-    let output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
-        let coord = cartesian_coord[i].clone();
-        let mut key = coord.clone();
-        key[dim] = 0;
-        let result = &results.get(&key).ok_or("missing result")?;
-        let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
-        Ok::<ValType<F>, region::RegionError>(o)
-    })?;
    region.apply_in_loop(&mut output, inner_loop_function)?;

    Ok(output.into())
}
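Per output coordinate, the in-circuit arithmetic is one multiply and one add: `linear = index * m[dim] + sum over i != dim of coord[i] * m[i]`, where `m[i]` is the row-major stride computed by the `dim_multiplier` loop (the product of the dims after `i`). A standalone check of that stride decomposition, using the doc comment's `[3, 5, 2]` shape:

```rust
fn strides(dims: &[usize]) -> Vec<usize> {
    // m[i] = product of dims[i + 1..], as in the dim_multiplier loop
    let mut m = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        m[i] = m[i + 1] * dims[i + 1];
    }
    m
}

fn main() {
    let dims = [3, 5, 2];
    let m = strides(&dims);
    assert_eq!(m, vec![10, 2, 1]);
    // linear index of coord [1, 4, 1] = 1*10 + 4*2 + 1*1 = 19
    let coord = [1usize, 4, 1];
    let linear: usize = coord.iter().zip(&m).map(|(c, s)| c * s).sum();
    assert_eq!(linear, 19);
}
```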
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 2],
    ordered: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    let (mut input, fullset) = (values[0].clone(), values[1].clone());
    let set_len = fullset.len();
    input.flatten();

    let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;

    let mut claimed_output: ValTensor<F> = if is_assigned {
        let input_evals = input.get_int_evals()?;
        let mut fullset_evals = fullset.get_int_evals()?.into_iter().collect::<Vec<_>>();

        // get the difference between the two vectors
        for eval in input_evals.iter() {
            // delete first occurrence of that value
            if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
                fullset_evals.remove(pos);
            }
        }

        // if fullset minus input has the expected length, then input is a subset of fullset; else truncate.
        // This is a patch for the fact that we can't have a tensor of unknowns when using constants during gen-settings
        if fullset_evals.len() != set_len - input.len() {
            fullset_evals.truncate(set_len - input.len());
        }

        fullset_evals
            .iter()
            .map(|x| Value::known(i128_to_felt(*x)))
            .collect::<Tensor<Value<F>>>()
            .into()
    } else {
        let dim = fullset.len() - input.len();
        Tensor::new(Some(&vec![Value::<F>::unknown(); dim]), &[dim])?.into()
    };

    // assign the claimed output
    claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;

    // input and claimed output should be shuffles of fullset
    // concatenate input and claimed output
    let input_and_claimed_output = input.concat(claimed_output.clone())?;

    // assert that this is a permutation/shuffle
    shuffles(
        config,
        region,
        &[input_and_claimed_output.clone()],
        &[fullset.clone()],
    )?;

    if ordered {
        // assert that the claimed output is sorted
        claimed_output = _sort_ascending(config, region, &[claimed_output])?;
    }

    Ok(claimed_output)
}
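The claimed complement is computed witness-side as a multiset difference and then constrained by the shuffle argument: `input ++ claimed_output` must be a permutation of `fullset`. A plain model of the witness computation:

```rust
// Witness-side model: remove one occurrence of each input value from the
// full set; what remains is the claimed missing-elements set.
fn missing_set_elements(input: &[i64], fullset: &[i64]) -> Vec<i64> {
    let mut rest = fullset.to_vec();
    for v in input {
        // delete the first occurrence of each input value
        if let Some(pos) = rest.iter().position(|x| x == v) {
            rest.remove(pos);
        }
    }
    rest
}

fn main() {
    assert_eq!(missing_set_elements(&[2, 0], &[0, 1, 2, 3]), vec![1, 3]);
}
```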
/// Gather accumulated layout
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
@@ -1150,88 +1129,81 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
    values: &[ValTensor<F>; 3],
    dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
-    let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone());
+    let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());

    assert_eq!(input.dims().len(), index.dims().len());

-    let mut assigned_len = vec![];
+    let input_dims = input.dims();

-    if !input.all_prev_assigned() {
-        input = region.assign(&config.custom_gates.inputs[0], &input)?;
-        assigned_len.push(input.len());
-    }
    if !index.all_prev_assigned() {
        index = region.assign(&config.custom_gates.inputs[1], &index)?;
-        assigned_len.push(index.len());
-    }
-    if !src.all_prev_assigned() {
-        src = region.assign(&config.custom_gates.output, &src)?;
-        assigned_len.push(src.len());
+        region.increment(index.len());
    }

-    if !assigned_len.is_empty() {
-        // safe to unwrap since we've just checked it has at least one element
-        region.increment(*assigned_len.iter().max().unwrap());
-    }
    let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;

-    // Calculate the output tensor size
-    let input_dim = input.dims()[dim];
-    let output_size = index.dims().to_vec();
+    let claimed_output: ValTensor<F> = if is_assigned {
+        let input_inner = input.get_int_evals()?;
+        let index_inner = index.get_int_evals()?.map(|x| x as usize);
+        let src_inner = src.get_int_evals()?;

-    // these will be assigned as constants
-    let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x)));
-    indices.set_visibility(&crate::graph::Visibility::Fixed);
-    let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
-    region.increment(indices.len());
+        let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;

-    // Allocate memory for the output tensor
-    let cartesian_coord = output_size
-        .iter()
-        .map(|x| 0..*x)
-        .multi_cartesian_product()
-        .collect::<Vec<_>>();
-
-    let mut output: Tensor<()> = Tensor::new(None, &output_size)?;
-
-    let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
-        let coord = cartesian_coord[i].clone();
-        let index_val = index.get_inner_tensor()?.get(&coord);
-
-        let src_val = src.get_inner_tensor()?.get(&coord);
-        let src_valtensor: ValTensor<F> = Tensor::from([src_val.clone()].into_iter()).into();
-
-        let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
-        slice[dim] = 0..input_dim;
-
-        let mut sliced_input = input.get_slice(&slice)?;
-        sliced_input.flatten();
-
-        let index_valtensor: ValTensor<F> = Tensor::from([index_val.clone()].into_iter()).into();
-
-        let mask = equals(config, region, &[index_valtensor, indices.clone()])?;
-
-        let res = iff(config, region, &[mask, src_valtensor, sliced_input])?;
-
-        let input_cartesian_coord = slice.into_iter().multi_cartesian_product();
-        let mutable_input_inner = input.get_inner_tensor_mut()?;
-
-        for (i, r) in res.get_inner_tensor()?.iter().enumerate() {
-            let coord = input_cartesian_coord
-                .clone()
-                .nth(i)
-                .ok_or("invalid coord")?;
-            *mutable_input_inner.get_mut(&coord) = r.clone();
-        }
-        Ok(())
+        res.iter()
+            .map(|x| Value::known(i128_to_felt(*x)))
+            .collect::<Tensor<Value<F>>>()
+            .into()
+    } else {
+        Tensor::new(
+            Some(&vec![Value::<F>::unknown(); input.len()]),
+            &[input.len()],
+        )?
+        .into()
    };

-    output
-        .iter_mut()
-        .enumerate()
-        .map(|(i, _)| inner_loop_function(i, region))
-        .collect::<Result<Vec<()>, Box<dyn Error>>>()?;
+    // assign the claimed output
+    let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
+    region.increment(claimed_output.len());
+    claimed_output.reshape(input.dims())?;

-    Ok(input)
+    // scatter elements is the inverse of gather elements
+    let (gather_src, linear_index) = gather_elements(
+        config,
+        region,
+        &[claimed_output.clone(), index.clone()],
+        dim,
+    )?;

+    // assert this is equal to the src
+    enforce_equality(config, region, &[gather_src, src])?;

+    let full_index_set: ValTensor<F> =
+        Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();

+    let input_indices = get_missing_set_elements(
+        config,
+        region,
+        &[linear_index, full_index_set.clone()],
+        true,
+    )?;

+    claimed_output.flatten();
+    let (gather_input, _) = gather_elements(
+        config,
+        region,
+        &[claimed_output.clone(), input_indices.clone()],
+        0,
+    )?;
+    // assert this is a subset of the input
+    dynamic_lookup(
+        config,
+        region,
+        &[input_indices, gather_input],
+        &[full_index_set, input.clone()],
+    )?;

+    claimed_output.reshape(input_dims)?;

+    Ok(claimed_output)
}
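The new scatter check is indirect: compute the scattered tensor witness-side, then constrain it by running `gather_elements` in reverse — gathering at `index` from the claimed output must return `src`, and the untouched positions must be a subset of the original input. A plain model of the witness-side scatter on a flat tensor:

```rust
// Witness-side model of scatter along dim 0 of a flat tensor:
// out = input, then out[index[i]] = src[i].
fn scatter_model(input: &[i64], index: &[usize], src: &[i64]) -> Vec<i64> {
    let mut out = input.to_vec();
    for (i, &pos) in index.iter().enumerate() {
        out[pos] = src[i];
    }
    out
}

fn main() {
    let out = scatter_model(&[0, 0, 0, 0], &[1, 3], &[7, 9]);
    assert_eq!(out, vec![0, 7, 0, 9]);
    // gathering back at the same indices recovers src, which is the
    // equality the circuit enforces on the claimed output
    assert_eq!(vec![out[1], out[3]], vec![7, 9]);
}
```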
/// sum accumulated layout
@@ -1488,17 +1460,11 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
    dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
-    // these will be assigned as constants
-    let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
-    indices.set_visibility(&crate::graph::Visibility::Fixed);
-    let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
-    region.increment(indices.len());

-    let argmax = move |config: &BaseConfig<F>,
-                       region: &mut RegionCtx<F>,
-                       values: &[ValTensor<F>; 1]|
-     -> Result<ValTensor<F>, Box<dyn Error>> {
-        argmax(config, region, values, indices.clone())
-    };
+    let argmax =
+        move |config: &BaseConfig<F>,
+              region: &mut RegionCtx<F>,
+              values: &[ValTensor<F>; 1]|
+              -> Result<ValTensor<F>, Box<dyn Error>> { argmax(config, region, values) };

    // calculate value of output
    axes_wise_op(config, region, values, &[dim], argmax)

@@ -1524,18 +1490,12 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
    dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    // calculate value of output
-    // these will be assigned as constants
-    let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
-    indices.set_visibility(&crate::graph::Visibility::Fixed);
-    let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
-    region.increment(indices.len());

-    let argmin = move |config: &BaseConfig<F>,
-                       region: &mut RegionCtx<F>,
-                       values: &[ValTensor<F>; 1]|
-     -> Result<ValTensor<F>, Box<dyn Error>> {
-        argmin(config, region, values, indices.clone())
-    };
+    let argmin =
+        move |config: &BaseConfig<F>,
+              region: &mut RegionCtx<F>,
+              values: &[ValTensor<F>; 1]|
+              -> Result<ValTensor<F>, Box<dyn Error>> { argmin(config, region, values) };

    axes_wise_op(config, region, values, &[dim], argmin)
}
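With the index constants now synthesized inside `argmax`/`argmin` themselves, the axes-wise wrappers shrink to a closure plus `axes_wise_op`. The per-slice reduction the closure performs is ordinary argmax; as a model:

```rust
// Model of the per-slice reduction: index of the maximum element,
// first occurrence on ties (ONNX's default tie-breaking).
fn argmax_model(xs: &[i64]) -> usize {
    let max = xs.iter().max().expect("non-empty slice");
    xs.iter().position(|v| v == max).unwrap()
}

fn main() {
    assert_eq!(argmax_model(&[3, 9, 2, 9]), 1);
}
```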
@@ -1851,7 +1811,8 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
    // take the product of diff and output
    let prod_check = pairwise(config, region, &[values, output.clone()], BaseOp::Mult)?;

-    is_zero_identity(config, region, &[prod_check], false)?;
+    let zero_tensor = create_zero_tensor(prod_check.len());
+    enforce_equality(config, region, &[prod_check, zero_tensor])?;

    Ok(output)
}
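`equals_zero` has the usual is-zero gadget shape: the witness output is 1 exactly where the value is 0, and the product constraint `value * output = 0` (now enforced with `enforce_equality` against a zero tensor rather than the removed `is_zero_identity`) rules out `output = 1` on nonzero values. A model of the witness and the product check:

```rust
// Witness-side model of the is-zero output plus the product constraint.
fn equals_zero_witness(x: i64) -> i64 {
    if x == 0 { 1 } else { 0 }
}

fn main() {
    for x in [-3i64, 0, 5] {
        let out = equals_zero_witness(x);
        assert_eq!(x * out, 0); // the product check enforced via enforce_equality
        assert!(out == 0 || out == 1); // booleanity of the mask
    }
}
```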
@@ -1963,13 +1924,7 @@ pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
        .map(|coord| {
            let (b, i) = (coord[0], coord[1]);
            let input = values[0].get_slice(&[b..b + 1, i..i + 1])?;
-            let output = conv(
-                config,
-                region,
-                &[input, kernel.clone()],
-                padding,
-                stride,
-            )?;
+            let output = conv(config, region, &[input, kernel.clone()], padding, stride)?;
            res.push(output);
            Ok(())
        })
@@ -2448,38 +2403,6 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
    Ok(output)
}

-/// is zero identity constraint.
-pub(crate) fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
-    config: &BaseConfig<F>,
-    region: &mut RegionCtx<F>,
-    values: &[ValTensor<F>; 1],
-    assign: bool,
-) -> Result<ValTensor<F>, Box<dyn Error>> {
-    let output = if assign || !values[0].get_const_indices()?.is_empty() {
-        let output = region.assign(&config.custom_gates.output, &values[0])?;
-        region.increment(output.len());
-        output
-    } else {
-        values[0].clone()
-    };
-    // Enable the selectors
-    if !region.is_dummy() {
-        (0..output.len())
-            .map(|j| {
-                let index = region.linear_coord() - j - 1;
-
-                let (x, y, z) = config.custom_gates.output.cartesian_coord(index);
-                let selector = config.custom_gates.selectors.get(&(BaseOp::IsZero, x, y));
-
-                region.enable(selector, z)?;
-                Ok(())
-            })
-            .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
-    }
-
-    Ok(output)
-}

/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
@@ -2747,7 +2670,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 1],
-    indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    // this is safe because we later constrain it
    let argmax = values[0]

@@ -2770,7 +2692,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
        config,
        region,
        &[values[0].clone(), assigned_argmax.clone()],
-        indices,
    )?;

    let max_val = max(config, region, &[values[0].clone()])?;

@@ -2785,7 +2706,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 1],
-    indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    // this is safe because we later constrain it
    let argmin = values[0]

@@ -2809,7 +2729,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
        config,
        region,
        &[values[0].clone(), assigned_argmin.clone()],
-        indices,
    )?;
    let min_val = min(config, region, &[values[0].clone()])?;
@@ -276,7 +276,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                if let Some(idx) = constant_idx {
                    tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
                } else {
-                    layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
+                    layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
                }
            }
            PolyOp::ScatterElements { dim, constant_idx } => {
@@ -402,9 +402,6 @@ pub enum Commands {
        /// Number of logrows to use for srs. Overrides settings_path if specified.
        #[arg(long, default_value = None)]
        logrows: Option<u32>,
-        /// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
-        #[arg(long, default_value = DEFAULT_CHECKMODE)]
-        check: CheckMode,
    },
    /// Loads model and input and runs mock prover (for testing)
    Mock {
@@ -159,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            srs_path,
            settings_path,
            logrows,
-            check,
-        } => get_srs_cmd(srs_path, settings_path, logrows, check).await,
+        } => get_srs_cmd(srs_path, settings_path, logrows).await,
        Commands::Table { model, args } => table(model, args),
        #[cfg(feature = "render")]
        Commands::RenderCircuit {
@@ -492,23 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}

#[cfg(not(target_arch = "wasm32"))]
-fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
+pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
    use std::io::Read;

-    let path = get_srs_path(logrows, srs_path);
-    let file = std::fs::File::open(path.clone())?;
+    let file = std::fs::File::open(path)?;
-    let mut reader = std::io::BufReader::new(file);
    let mut buffer = vec![];
+    let mut reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
    let bytes_read = reader.read_to_end(&mut buffer)?;

    info!(
-        "read {} bytes from SRS file (vector of len = {})",
+        "read {} bytes from file (vector of len = {})",
        bytes_read,
        buffer.len()
    );

    let hash = sha256::digest(buffer);
-    info!("SRS hash: {}", hash);
+    info!("file hash: {}", hash);

    Ok(hash)
}

+#[cfg(not(target_arch = "wasm32"))]
+fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
+    let path = get_srs_path(logrows, srs_path);
+    let hash = get_file_hash(&path)?;

    let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
        Some(h) => h,
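The refactor splits hashing out of `check_srs_hash` so any file can be fingerprinted. A standalone sketch of the same pattern — buffered read-to-end, then a sha256 digest — using the `sha256` crate as in the diff; the 8 KiB capacity is an arbitrary stand-in for `*EZKL_BUF_CAPACITY`, and the demo file name is hypothetical:

```rust
use std::io::Read;
use std::path::Path;

fn file_hash(path: &Path) -> Result<String, Box<dyn std::error::Error>> {
    let file = std::fs::File::open(path)?;
    // buffered read-to-end, mirroring the BufReader::with_capacity call above
    let mut reader = std::io::BufReader::with_capacity(8 * 1024, file);
    let mut buffer = vec![];
    reader.read_to_end(&mut buffer)?;
    Ok(sha256::digest(buffer))
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    std::fs::write("demo.bin", b"hello")?;
    println!("{}", file_hash(Path::new("demo.bin"))?);
    Ok(())
}
```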
@@ -532,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
    srs_path: Option<PathBuf>,
    settings_path: Option<PathBuf>,
    logrows: Option<u32>,
-    check_mode: CheckMode,
) -> Result<String, Box<dyn Error>> {
    // logrows overrides settings

@@ -560,21 +563,20 @@ pub(crate) async fn get_srs_cmd(
        let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
        let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
        // check the SRS
-        if matches!(check_mode, CheckMode::SAFE) {
-            #[cfg(not(target_arch = "wasm32"))]
-            let pb = init_spinner();
-            #[cfg(not(target_arch = "wasm32"))]
-            pb.set_message("Validating SRS (this may take a while) ...");
-            ParamsKZG::<Bn256>::read(&mut reader)?;
-            #[cfg(not(target_arch = "wasm32"))]
-            pb.finish_with_message("SRS validated");
-        }
+        #[cfg(not(target_arch = "wasm32"))]
+        let pb = init_spinner();
+        #[cfg(not(target_arch = "wasm32"))]
+        pb.set_message("Validating SRS (this may take a while) ...");
+        let params = ParamsKZG::<Bn256>::read(&mut reader)?;
+        #[cfg(not(target_arch = "wasm32"))]
+        pb.finish_with_message("SRS validated.");

        info!("Saving SRS to disk...");
        let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;

        let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
-        buffer.write_all(reader.get_ref())?;
-        buffer.flush()?;
+        params.write(&mut buffer)?;

        info!("Saved SRS to disk.");

        info!("SRS downloaded");
    } else {
@@ -969,8 +971,8 @@ pub(crate) fn calibrate(
                    continue;
                }
            }

-    // drop the gag
+            // drop the gag
            #[cfg(unix)]
            drop(_r);
            #[cfg(unix)]
@@ -1695,7 +1697,7 @@ pub(crate) fn fuzz(
    let logrows = circuit.settings().run_args.logrows;

    info!("setting up tests");

    #[cfg(unix)]
    let _r = Gag::stdout()?;
    let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);

@@ -1713,6 +1715,7 @@ pub(crate) fn fuzz(
    let public_inputs = circuit.prepare_public_inputs(&data)?;

    let strategy = KZGSingleStrategy::new(&params);
+    #[cfg(unix)]
    std::mem::drop(_r);

    info!("starting fuzzing");

@@ -1903,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
    passed: &AtomicBool,
) {
    let num_failures = AtomicI64::new(0);
+    #[cfg(unix)]
    let _r = Gag::stdout().unwrap();

    let pb = init_bar(num_runs as u64);

@@ -1916,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
        pb.inc(1);
    });
    pb.finish_with_message("Done.");
+    #[cfg(unix)]
    std::mem::drop(_r);
    info!(
        "num failures: {} out of {}",
@@ -484,7 +484,6 @@ fn get_srs(
        srs_path,
        settings_path,
        logrows,
-        CheckMode::SAFE,
    ))
    .map_err(|e| {
        let err_str = format!("Failed to get srs: {}", e);
@@ -494,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
            }
            _ => return Err(Box::new(TensorError::WrongMethod)),
        };
-        Ok(integer_evals.into_iter().into())
+        let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
+        match tensor.reshape(self.dims()) {
+            _ => {}
+        };
+
+        Ok(tensor)
    }

    /// Calls `pad_to_zero_rem` on the inner tensor.