mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
469f2d2e71 | ||
|
|
352812b9ac |
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '16.2.8'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ use crate::{
|
||||
use super::*;
|
||||
use crate::circuit::ops::lookup::LookupOp;
|
||||
|
||||
const ASCII_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz";
|
||||
|
||||
/// Calculate the L1 distance between two tensors.
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
@@ -2671,9 +2673,7 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
rhs.expand(&broadcasted_shape)?;
|
||||
|
||||
let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;
|
||||
|
||||
let sign = sign(config, region, &[diff])?;
|
||||
|
||||
equals(config, region, &[sign, create_unit_tensor(1)])
|
||||
}
|
||||
|
||||
@@ -5286,75 +5286,72 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
|
||||
base: &usize,
|
||||
n: &usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let input = values[0].clone();
|
||||
let mut input = values[0].clone();
|
||||
|
||||
let is_assigned = !input.all_prev_assigned();
|
||||
|
||||
let bases: ValTensor<F> = Tensor::from(
|
||||
(0..*n)
|
||||
.rev()
|
||||
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))),
|
||||
if !is_assigned {
|
||||
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
||||
}
|
||||
|
||||
let mut bases: ValTensor<F> = Tensor::from(
|
||||
// repeat it input.len() times
|
||||
(0..input.len()).flat_map(|_| {
|
||||
(0..*n)
|
||||
.rev()
|
||||
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep)))
|
||||
}),
|
||||
)
|
||||
.into();
|
||||
let mut bases_dims = input.dims().to_vec();
|
||||
bases_dims.push(*n);
|
||||
bases.reshape(&bases_dims)?;
|
||||
|
||||
let cartesian_coord = input
|
||||
.dims()
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
let mut decomposed_dims = input.dims().to_vec();
|
||||
decomposed_dims.push(*n + 1);
|
||||
|
||||
let mut output: Tensor<Tensor<ValType<F>>> = Tensor::new(None, input.dims())?;
|
||||
let claimed_output = if region.witness_gen() {
|
||||
input.decompose(*base, *n)?
|
||||
} else {
|
||||
let decomposed_len = decomposed_dims.iter().product();
|
||||
let claimed_output = Tensor::new(
|
||||
Some(&vec![ValType::Value(Value::unknown()); decomposed_len]),
|
||||
&decomposed_dims,
|
||||
)?;
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = input.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
claimed_output.into()
|
||||
};
|
||||
region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
if !is_assigned {
|
||||
sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
|
||||
}
|
||||
let input_slice = input.dims().iter().map(|x| 0..*x).collect::<Vec<_>>();
|
||||
let mut sign_slice = input_slice.clone();
|
||||
sign_slice.push(0..1);
|
||||
let mut rest_slice = input_slice.clone();
|
||||
rest_slice.push(1..n + 1);
|
||||
|
||||
let mut claimed_output_slice = if region.witness_gen() {
|
||||
sliced_input.decompose(*base, *n)?
|
||||
} else {
|
||||
Tensor::from(vec![ValType::Value(Value::unknown()); *n + 1].into_iter()).into()
|
||||
};
|
||||
let sign = claimed_output.get_slice(&sign_slice)?;
|
||||
let rest = claimed_output.get_slice(&rest_slice)?;
|
||||
|
||||
claimed_output_slice =
|
||||
region.assign(&config.custom_gates.inputs[1], &claimed_output_slice)?;
|
||||
claimed_output_slice.flatten();
|
||||
let sign = range_check(config, region, &[sign], &(-1, 1))?;
|
||||
let rest = range_check(config, region, &[rest], &(0, (*base - 1) as i128))?;
|
||||
|
||||
region.increment(claimed_output_slice.len());
|
||||
// equation needs to be constructed as ij,ij->i but for arbitrary n dims we need to construct this dynamically
|
||||
// indices should map in order of the alphabet
|
||||
// start with lhs
|
||||
let lhs = ASCII_ALPHABET.chars().take(rest.dims().len()).join("");
|
||||
let rhs = ASCII_ALPHABET.chars().take(rest.dims().len() - 1).join("");
|
||||
let equation = format!("{},{}->{}", lhs, lhs, rhs);
|
||||
|
||||
// get the sign bit and make sure it is valid
|
||||
let sign = claimed_output_slice.first()?;
|
||||
let sign = range_check(config, region, &[sign], &(-1, 1))?;
|
||||
// now add the rhs
|
||||
|
||||
// get the rest of the thing and make sure it is in the correct range
|
||||
let rest = claimed_output_slice.get_slice(&[1..claimed_output_slice.len()])?;
|
||||
let prod_decomp = einsum(config, region, &[rest.clone(), bases], &equation)?;
|
||||
|
||||
let rest = range_check(config, region, &[rest], &(0, (base - 1) as i128))?;
|
||||
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
|
||||
|
||||
let prod_decomp = dot(config, region, &[rest, bases.clone()])?;
|
||||
enforce_equality(config, region, &[input, signed_decomp])?;
|
||||
|
||||
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
|
||||
|
||||
enforce_equality(config, region, &[sliced_input, signed_decomp])?;
|
||||
|
||||
Ok(claimed_output_slice.get_inner_tensor()?.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut output, inner_loop_function)?;
|
||||
|
||||
let mut combined_output = output.combine()?;
|
||||
let mut output_dims = input.dims().to_vec();
|
||||
output_dims.push(*n + 1);
|
||||
combined_output.reshape(&output_dims)?;
|
||||
|
||||
Ok(combined_output.into())
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
Reference in New Issue
Block a user