Compare commits

...

3 Commits

Author SHA1 Message Date
dante
ea535e2ecd refactor: use linear index constraints for gather and scatter (#735) 2024-03-09 18:00:21 +00:00
Alexander Camuto
f8aa91ed08 fix: windows compile 2024-03-06 11:40:44 +00:00
dante
a59e3780b2 chore: rm recip_int helper (#733) 2024-03-05 21:51:14 +00:00
8 changed files with 266 additions and 354 deletions

View File

@@ -17,7 +17,6 @@ pub enum BaseOp {
Sub,
SumInit,
Sum,
IsZero,
IsBoolean,
}
@@ -35,7 +34,6 @@ impl BaseOp {
BaseOp::Add => a + b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::IsZero => b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
@@ -76,7 +74,6 @@ impl BaseOp {
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::IsZero => "ISZERO",
BaseOp::IsBoolean => "ISBOOLEAN",
}
}
@@ -93,7 +90,6 @@ impl BaseOp {
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::IsZero => (0, 1),
BaseOp::IsBoolean => (0, 1),
}
}
@@ -110,7 +106,6 @@ impl BaseOp {
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}
@@ -127,7 +122,6 @@ impl BaseOp {
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}

View File

@@ -387,7 +387,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
}
}
@@ -432,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
BaseOp::IsZero => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
vec![expected_output[base_op.constraint_idx()].clone()]
}
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)

View File

@@ -132,41 +132,6 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
Ok(claimed_output)
}
fn recip_int<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, Box<dyn Error>> {
// assert is boolean
let zero_inverse_val = tensor::ops::nonlinearities::zero_recip(1.0)[0];
// get values where input is 0
let zero_mask = equals_zero(config, region, input)?;
let zero_mask_minus_one = pairwise(
config,
region,
&[zero_mask.clone(), create_unit_tensor(1)],
BaseOp::Sub,
)?;
let zero_inverse_val = pairwise(
config,
region,
&[
zero_mask,
create_constant_tensor(i128_to_felt(zero_inverse_val), 1),
],
BaseOp::Mult,
)?;
pairwise(
config,
region,
&[zero_mask_minus_one, zero_inverse_val],
BaseOp::Add,
)
}
/// recip accumulated layout
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -175,10 +140,6 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
input_scale: F,
output_scale: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if output_scale == F::ONE || output_scale == F::ZERO {
return recip_int(config, region, value);
}
let input = value[0].clone();
let input_dims = input.dims();
@@ -188,8 +149,11 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
// range_check_bracket is min of input_scale * output_scale and 2^F::S - 3
let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4));
let input_scale_ratio =
i128_to_felt(integer_input_scale * integer_output_scale / range_check_len);
let input_scale_ratio = if range_check_len > 0 {
i128_to_felt(integer_input_scale * integer_output_scale / range_check_len)
} else {
F::ONE
};
let range_check_bracket = range_check_len / 2;
@@ -234,11 +198,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
let equal_inverse_mask = equals(
config,
region,
&[claimed_output.clone(), zero_inverse],
)?;
let equal_inverse_mask = equals(config, region, &[claimed_output.clone(), zero_inverse])?;
// assert the two masks are equal
enforce_equality(
@@ -249,12 +209,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
let unit_scale = create_constant_tensor(i128_to_felt(range_check_len), 1);
let unit_mask = pairwise(
config,
region,
&[equal_zero_mask, unit_scale],
BaseOp::Mult,
)?;
let unit_mask = pairwise(config, region, &[equal_zero_mask, unit_scale], BaseOp::Mult)?;
// now add the unit mask to the rebased_div
let rebased_offset_div = pairwise(config, region, &[rebased_div, unit_mask], BaseOp::Add)?;
@@ -691,14 +646,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
dim_indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, index) = (values[0].clone(), values[1].clone());
input.flatten();
if !(dim_indices.all_prev_assigned() || region.is_dummy()) {
return Err("dim_indices must be assigned".into());
}
// these will be assigned as constants
let dim_indices: ValTensor<F> =
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;
@@ -974,91 +928,25 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 2],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index_clone) = (values[0].clone(), values[1].clone());
let (input, mut index_clone) = (values[0].clone(), values[1].clone());
index_clone.flatten();
if index_clone.is_singleton() {
index_clone.reshape(&[1])?;
}
let mut assigned_len = vec![];
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
assigned_len.push(input.len());
}
if !index_clone.all_prev_assigned() {
index_clone = region.assign(&config.custom_gates.inputs[1], &index_clone)?;
assigned_len.push(index_clone.len());
}
if !assigned_len.is_empty() {
// safe to unwrap since we've just checked it has at least one element
region.increment(*assigned_len.iter().max().unwrap());
}
// Calculate the output tensor size
let input_dims = input.dims();
let mut output_size = input_dims.to_vec();
output_size[dim] = index_clone.dims()[0];
// these will be assigned as constants
let mut indices = Tensor::from((0..input.dims()[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let linear_index =
linearize_element_index(config, region, &[index_clone], input_dims, dim, true)?;
let mut iteration_dims = output_size.clone();
iteration_dims[dim] = 1;
let mut output = select(config, region, &[input, linear_index])?;
// Allocate memory for the output tensor
let cartesian_coord = iteration_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut results = HashMap::new();
for coord in cartesian_coord {
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dims[dim];
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let res = select(
config,
region,
&[sliced_input, index_clone.clone()],
indices.clone(),
)?;
results.insert(coord, res);
}
// Allocate memory for the output tensor
let cartesian_coord = output_size
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
let coord = cartesian_coord[i].clone();
let mut key = coord.clone();
key[dim] = 0;
let result = &results.get(&key).ok_or("missing result")?;
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
Ok::<ValType<F>, region::RegionError>(o)
})?;
// Reshape the output tensor
if index_clone.is_singleton() {
output_size.remove(dim);
}
output.reshape(&output_size)?;
Ok(output.into())
Ok(output)
}
/// Gather accumulated layout
@@ -1067,82 +955,173 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index) = (values[0].clone(), values[1].clone());
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
let (input, index) = (values[0].clone(), values[1].clone());
assert_eq!(input.dims().len(), index.dims().len());
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
}
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
}
region.increment(std::cmp::max(input.len(), index.len()));
// Calculate the output tensor size
let input_dims = input.dims();
let output_size = index.dims().to_vec();
// these will be assigned as constants
let mut indices = Tensor::from((0..input_dims[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let linear_index = linearize_element_index(config, region, &[index], input.dims(), dim, false)?;
let mut iteration_dims = output_size.clone();
iteration_dims[dim] = 1;
let mut output = select(config, region, &[input, linear_index.clone()])?;
output.reshape(&output_size)?;
Ok((output, linear_index))
}
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// FOr instance if the dims is [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
dims: &[usize],
dim: usize,
is_flat_index: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let index = values[0].clone();
if !is_flat_index {
assert_eq!(index.dims().len(), dims.len());
// if the index is already flat, return it
if index.dims().len() == 1 {
return Ok(index);
}
}
let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;
let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
let mut res = 1;
for dim in dims.iter().skip(i + 1) {
res *= dim;
}
Ok::<_, region::RegionError>(F::from(res as u64))
})?;
let iteration_dims = if is_flat_index {
let mut dims = dims.to_vec();
dims[dim] = index.len();
dims
} else {
index.dims().to_vec()
};
// Allocate memory for the output tensor
let cartesian_coord = iteration_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut results = HashMap::new();
let val_dim_multiplier: ValTensor<F> = dim_multiplier
.get_slice(&[dim..dim + 1])?
.map(|x| ValType::Constant(x))
.into();
for coord in cartesian_coord {
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dims[dim];
let mut output = Tensor::new(None, &[cartesian_coord.len()])?;
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
let coord = cartesian_coord[i].clone();
let slice: Vec<Range<usize>> = if is_flat_index {
coord[dim..dim + 1].iter().map(|x| *x..*x + 1).collect()
} else {
coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>()
};
slice[dim] = 0..output_size[dim];
let mut sliced_index = index.get_slice(&slice)?;
sliced_index.flatten();
let index_val = index.get_slice(&slice)?;
let res = select(
let mut const_offset = F::ZERO;
for i in 0..dims.len() {
if i != dim {
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
}
}
let const_offset = create_constant_tensor(const_offset, 1);
let res = pairwise(
config,
region,
&[sliced_input, sliced_index],
indices.clone(),
&[index_val, val_dim_multiplier.clone()],
BaseOp::Mult,
)?;
results.insert(coord, res);
}
let res = pairwise(config, region, &[res, const_offset], BaseOp::Add)?;
// Allocate memory for the output tensor
let cartesian_coord = output_size
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
Ok(res.get_inner_tensor()?[0].clone())
};
let output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
let coord = cartesian_coord[i].clone();
let mut key = coord.clone();
key[dim] = 0;
let result = &results.get(&key).ok_or("missing result")?;
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
Ok::<ValType<F>, region::RegionError>(o)
})?;
region.apply_in_loop(&mut output, inner_loop_function)?;
Ok(output.into())
}
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
ordered: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, fullset) = (values[0].clone(), values[1].clone());
let set_len = fullset.len();
input.flatten();
let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
let mut fullset_evals = fullset.get_int_evals()?.into_iter().collect::<Vec<_>>();
// get the difference between the two vectors
for eval in input_evals.iter() {
// delete first occurence of that value
if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
fullset_evals.remove(pos);
}
}
// if fullset + input is the same length, then input is a subset of fullset, else randomly delete elements, this is a patch for
// the fact that we can't have a tensor of unknowns when using constant during gen-settings
if fullset_evals.len() != set_len - input.len() {
fullset_evals.truncate(set_len - input.len());
}
fullset_evals
.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
let dim = fullset.len() - input.len();
Tensor::new(Some(&vec![Value::<F>::unknown(); dim]), &[dim])?.into()
};
// assign the claimed output
claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
// input and claimed output should be the shuffles of fullset
// concatentate input and claimed output
let input_and_claimed_output = input.concat(claimed_output.clone())?;
// assert that this is a permutation/shuffle
shuffles(
config,
region,
&[input_and_claimed_output.clone()],
&[fullset.clone()],
)?;
if ordered {
// assert that the claimed output is sorted
claimed_output = _sort_ascending(config, region, &[claimed_output])?;
}
Ok(claimed_output)
}
/// Gather accumulated layout
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -1150,88 +1129,81 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 3],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone());
let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());
assert_eq!(input.dims().len(), index.dims().len());
let mut assigned_len = vec![];
let input_dims = input.dims();
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
assigned_len.push(input.len());
}
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
assigned_len.push(index.len());
}
if !src.all_prev_assigned() {
src = region.assign(&config.custom_gates.output, &src)?;
assigned_len.push(src.len());
region.increment(index.len());
}
if !assigned_len.is_empty() {
// safe to unwrap since we've just checked it has at least one element
region.increment(*assigned_len.iter().max().unwrap());
}
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
// Calculate the output tensor size
let input_dim = input.dims()[dim];
let output_size = index.dims().to_vec();
let claimed_output: ValTensor<F> = if is_assigned {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
// these will be assigned as constants
let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;
// Allocate memory for the output tensor
let cartesian_coord = output_size
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut output: Tensor<()> = Tensor::new(None, &output_size)?;
let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
let coord = cartesian_coord[i].clone();
let index_val = index.get_inner_tensor()?.get(&coord);
let src_val = src.get_inner_tensor()?.get(&coord);
let src_valtensor: ValTensor<F> = Tensor::from([src_val.clone()].into_iter()).into();
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dim;
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let index_valtensor: ValTensor<F> = Tensor::from([index_val.clone()].into_iter()).into();
let mask = equals(config, region, &[index_valtensor, indices.clone()])?;
let res = iff(config, region, &[mask, src_valtensor, sliced_input])?;
let input_cartesian_coord = slice.into_iter().multi_cartesian_product();
let mutable_input_inner = input.get_inner_tensor_mut()?;
for (i, r) in res.get_inner_tensor()?.iter().enumerate() {
let coord = input_cartesian_coord
.clone()
.nth(i)
.ok_or("invalid coord")?;
*mutable_input_inner.get_mut(&coord) = r.clone();
}
Ok(())
res.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
output
.iter_mut()
.enumerate()
.map(|(i, _)| inner_loop_function(i, region))
.collect::<Result<Vec<()>, Box<dyn Error>>>()?;
// assign the claimed output
let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());
claimed_output.reshape(input.dims())?;
Ok(input)
// scatter elements is the inverse of gather elements
let (gather_src, linear_index) = gather_elements(
config,
region,
&[claimed_output.clone(), index.clone()],
dim,
)?;
// assert this is equal to the src
enforce_equality(config, region, &[gather_src, src])?;
let full_index_set: ValTensor<F> =
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
let input_indices = get_missing_set_elements(
config,
region,
&[linear_index, full_index_set.clone()],
true,
)?;
claimed_output.flatten();
let (gather_input, _) = gather_elements(
config,
region,
&[claimed_output.clone(), input_indices.clone()],
0,
)?;
// assert this is a subset of the input
dynamic_lookup(
config,
region,
&[input_indices, gather_input],
&[full_index_set, input.clone()],
)?;
claimed_output.reshape(input_dims)?;
Ok(claimed_output)
}
/// sum accumulated layout
@@ -1488,17 +1460,11 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// these will be assigned as constants
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let argmax = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> {
argmax(config, region, values, indices.clone())
};
let argmax =
move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> { argmax(config, region, values) };
// calculate value of output
axes_wise_op(config, region, values, &[dim], argmax)
@@ -1524,18 +1490,12 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// calculate value of output
// these will be assigned as constants
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let argmin = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> {
argmin(config, region, values, indices.clone())
};
let argmin =
move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> { argmin(config, region, values) };
axes_wise_op(config, region, values, &[dim], argmin)
}
@@ -1851,7 +1811,8 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
// take the product of diff and output
let prod_check = pairwise(config, region, &[values, output.clone()], BaseOp::Mult)?;
is_zero_identity(config, region, &[prod_check], false)?;
let zero_tensor = create_zero_tensor(prod_check.len());
enforce_equality(config, region, &[prod_check, zero_tensor])?;
Ok(output)
}
@@ -1963,13 +1924,7 @@ pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
.map(|coord| {
let (b, i) = (coord[0], coord[1]);
let input = values[0].get_slice(&[b..b + 1, i..i + 1])?;
let output = conv(
config,
region,
&[input, kernel.clone()],
padding,
stride,
)?;
let output = conv(config, region, &[input, kernel.clone()], padding, stride)?;
res.push(output);
Ok(())
})
@@ -2448,38 +2403,6 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// is zero identity constraint.
pub(crate) fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let output = if assign || !values[0].get_const_indices()?.is_empty() {
let output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
output
} else {
values[0].clone()
};
// Enable the selectors
if !region.is_dummy() {
(0..output.len())
.map(|j| {
let index = region.linear_coord() - j - 1;
let (x, y, z) = config.custom_gates.output.cartesian_coord(index);
let selector = config.custom_gates.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
Ok(output)
}
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -2747,7 +2670,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// this is safe because we later constrain it
let argmax = values[0]
@@ -2770,7 +2692,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
config,
region,
&[values[0].clone(), assigned_argmax.clone()],
indices,
)?;
let max_val = max(config, region, &[values[0].clone()])?;
@@ -2785,7 +2706,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// this is safe because we later constrain it
let argmin = values[0]
@@ -2809,7 +2729,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
config,
region,
&[values[0].clone(), assigned_argmin.clone()],
indices,
)?;
let min_val = min(config, region, &[values[0].clone()])?;

View File

@@ -276,7 +276,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
}
}
PolyOp::ScatterElements { dim, constant_idx } => {

View File

@@ -402,9 +402,6 @@ pub enum Commands {
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
logrows: Option<u32>,
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check: CheckMode,
},
/// Loads model and input and runs mock prover (for testing)
Mock {

View File

@@ -159,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
settings_path,
logrows,
check,
} => get_srs_cmd(srs_path, settings_path, logrows, check).await,
} => get_srs_cmd(srs_path, settings_path, logrows).await,
Commands::Table { model, args } => table(model, args),
#[cfg(feature = "render")]
Commands::RenderCircuit {
@@ -492,23 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
use std::io::Read;
let path = get_srs_path(logrows, srs_path);
let file = std::fs::File::open(path.clone())?;
let file = std::fs::File::open(path)?;
let mut reader = std::io::BufReader::new(file);
let mut buffer = vec![];
let mut reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let bytes_read = reader.read_to_end(&mut buffer)?;
info!(
"read {} bytes from SRS file (vector of len = {})",
"read {} bytes from file (vector of len = {})",
bytes_read,
buffer.len()
);
let hash = sha256::digest(buffer);
info!("SRS hash: {}", hash);
info!("file hash: {}", hash);
Ok(hash)
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
let path = get_srs_path(logrows, srs_path);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
Some(h) => h,
@@ -532,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
srs_path: Option<PathBuf>,
settings_path: Option<PathBuf>,
logrows: Option<u32>,
check_mode: CheckMode,
) -> Result<String, Box<dyn Error>> {
// logrows overrides settings
@@ -560,21 +563,20 @@ pub(crate) async fn get_srs_cmd(
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
// check the SRS
if matches!(check_mode, CheckMode::SAFE) {
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated");
}
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated.");
info!("Saving SRS to disk...");
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
buffer.write_all(reader.get_ref())?;
buffer.flush()?;
params.write(&mut buffer)?;
info!("Saved SRS to disk.");
info!("SRS downloaded");
} else {
@@ -969,8 +971,8 @@ pub(crate) fn calibrate(
continue;
}
}
// drop the gag
// drop the gag
#[cfg(unix)]
drop(_r);
#[cfg(unix)]
@@ -1695,7 +1697,7 @@ pub(crate) fn fuzz(
let logrows = circuit.settings().run_args.logrows;
info!("setting up tests");
#[cfg(unix)]
let _r = Gag::stdout()?;
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
@@ -1713,6 +1715,7 @@ pub(crate) fn fuzz(
let public_inputs = circuit.prepare_public_inputs(&data)?;
let strategy = KZGSingleStrategy::new(&params);
#[cfg(unix)]
std::mem::drop(_r);
info!("starting fuzzing");
@@ -1903,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
passed: &AtomicBool,
) {
let num_failures = AtomicI64::new(0);
#[cfg(unix)]
let _r = Gag::stdout().unwrap();
let pb = init_bar(num_runs as u64);
@@ -1916,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
pb.inc(1);
});
pb.finish_with_message("Done.");
#[cfg(unix)]
std::mem::drop(_r);
info!(
"num failures: {} out of {}",

View File

@@ -484,7 +484,6 @@ fn get_srs(
srs_path,
settings_path,
logrows,
CheckMode::SAFE,
))
.map_err(|e| {
let err_str = format!("Failed to get srs: {}", e);

View File

@@ -494,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
Ok(integer_evals.into_iter().into())
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
_ => {}
};
Ok(tensor)
}
/// Calls `pad_to_zero_rem` on the inner tensor.