Compare commits

...

5 Commits

Author  SHA1        Message                                                Date
dante   9c699e30cb  Merge branch 'main' into ac/artifact-version-warning   2025-01-06 15:49:58 +00:00
dante   e86caca8b6  refactor: batched poly reads (#897)                     2025-01-06 15:49:47 +00:00
dante   3b8e44df9b  chore: version mismatch warnings for artifacts          2025-01-06 15:35:37 +00:00
dante   c839a30ae6  fix: clearer duplication functions (#895)               2024-12-31 07:28:02 -05:00
dante   352812b9ac  refactor!: simplified decompose op (#892)               2024-12-30 13:44:03 -05:00
13 changed files with 216 additions and 203 deletions

Cargo.lock (generated)
View File

@@ -2377,7 +2377,7 @@ dependencies = [
[[package]]
name = "halo2_gadgets"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b"
source = "git+https://github.com/zkonduit/halo2#6d72498928cdb69ce0de9f2230d2873ca2cf5324"
dependencies = [
"arrayvec 0.7.4",
"bitvec",
@@ -2394,7 +2394,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b#0654e92bdf725fd44d849bfef3643870a8c7d50b"
source = "git+https://github.com/zkonduit/halo2#6d72498928cdb69ce0de9f2230d2873ca2cf5324#6d72498928cdb69ce0de9f2230d2873ca2cf5324"
dependencies = [
"bincode",
"blake2b_simd",

View File

@@ -280,7 +280,10 @@ no-update = []
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b", package = "halo2_proofs" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#6d72498928cdb69ce0de9f2230d2873ca2cf5324", package = "halo2_proofs" }
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#6d72498928cdb69ce0de9f2230d2873ca2cf5324", package = "halo2_proofs" }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }

View File

@@ -30,6 +30,8 @@ use crate::{
use super::*;
use crate::circuit::ops::lookup::LookupOp;
const ASCII_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz";
/// Calculate the L1 distance between two tensors.
/// ```
/// use ezkl::tensor::Tensor;
@@ -418,10 +420,6 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values[0].remove_indices(&mut removal_indices, true)?;
values[1].remove_indices(&mut removal_indices, true)?;
let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);
let start = instant::Instant::now();
let mut inputs = vec![];
let block_width = config.custom_gates.output.num_inner_cols();
@@ -429,37 +427,22 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
for (i, input) in values.iter_mut().enumerate() {
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let inp = {
let (res, len) = region.assign_with_duplication(
&config.custom_gates.inputs[i],
input,
&config.check_mode,
false,
)?;
let (res, len) = region
.assign_with_duplication_unconstrained(&config.custom_gates.inputs[i], input)?;
assigned_len = len;
res.get_inner()?
};
inputs.push(inp);
}
let elapsed = start.elapsed();
trace!("assigning inputs took: {:?}", elapsed);
// Now we can assign the dot product
// time this step
let start = instant::Instant::now();
let accumulated_dot = accumulated::dot(&[inputs[0].clone(), inputs[1].clone()], block_width)?;
let elapsed = start.elapsed();
trace!("calculating accumulated dot took: {:?}", elapsed);
let start = instant::Instant::now();
let (output, output_assigned_len) = region.assign_with_duplication(
let (output, output_assigned_len) = region.assign_with_duplication_constrained(
&config.custom_gates.output,
&accumulated_dot.into(),
&config.check_mode,
true,
)?;
let elapsed = start.elapsed();
trace!("assigning output took: {:?}", elapsed);
// enable the selectors
if !region.is_dummy() {
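For intuition, a minimal plain-integer sketch of the running dot product that accumulated::dot produces (hypothetical: field elements swapped for i64 and block_width fixed at 1; the sum and prod changes below follow the same pattern):

// Hypothetical model of accumulated::dot: out[k] holds the running sum of
// a[i] * b[i] for i <= k. The custom gate constrains each step, which is why
// the inputs only need the unconstrained duplicated layout while the running
// output gets the copy-constrained one.
fn accumulated_dot(a: &[i64], b: &[i64]) -> Vec<i64> {
    a.iter()
        .zip(b.iter())
        .scan(0i64, |acc, (x, y)| {
            *acc += x * y;
            Some(*acc)
        })
        .collect()
}

fn main() {
    // 1*4 = 4, 4 + 2*5 = 14, 14 + 3*6 = 32
    assert_eq!(accumulated_dot(&[1, 2, 3], &[4, 5, 6]), vec![4, 14, 32]);
}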
@@ -1000,7 +983,6 @@ fn select<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, CircuitError> {
let start = instant::Instant::now();
let (mut input, index) = (values[0].clone(), values[1].clone());
input.flatten();
@@ -1028,9 +1010,6 @@ fn select<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let (_, assigned_output) =
dynamic_lookup(config, region, &[index, output], &[dim_indices, input])?;
let end = start.elapsed();
trace!("select took: {:?}", end);
Ok(assigned_output)
}
@@ -1092,7 +1071,6 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
lookups: &[ValTensor<F>; 2],
tables: &[ValTensor<F>; 2],
) -> Result<(ValTensor<F>, ValTensor<F>), CircuitError> {
let start = instant::Instant::now();
// if not all lookups same length err
if lookups[0].len() != lookups[1].len() {
return Err(CircuitError::MismatchedLookupLength(
@@ -1126,28 +1104,20 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
}
let table_len = table_0.len();
trace!("assigning tables took: {:?}", start.elapsed());
// now create a vartensor of constants for the dynamic lookup index
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
let _table_index =
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
trace!("assigning table index took: {:?}", start.elapsed());
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
let lookup_len = lookup_0.len();
trace!("assigning lookups took: {:?}", start.elapsed());
// now set the lookup index
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
trace!("assigning lookup index took: {:?}", start.elapsed());
let mut lookup_block = 0;
if !region.is_dummy() {
@@ -1194,9 +1164,6 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
region.increment_dynamic_lookup_index(1);
region.increment(lookup_len);
let end = start.elapsed();
trace!("dynamic lookup took: {:?}", end);
Ok((lookup_0, lookup_1))
}
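A rough model of what the dynamic lookup argument enforces (hypothetical, over plain integers rather than field elements): every lookup row, tagged with its dynamic_lookup_index, must appear among the table rows carrying the same tag, which is why both sides get an index column of constants.

use std::collections::HashSet;

// Hypothetical sketch: rows are (index, col_0, col_1) triples; the lookup
// holds iff every lookup row appears verbatim among the table rows. The
// shared index constant stops rows of one lookup call from matching a table
// registered by a different call.
fn dynamic_lookup_holds(lookups: &[(u64, i64, i64)], tables: &[(u64, i64, i64)]) -> bool {
    let table_rows: HashSet<_> = tables.iter().collect();
    lookups.iter().all(|row| table_rows.contains(row))
}

fn main() {
    let tables = [(0, 1, 10), (0, 2, 20), (0, 3, 30)];
    assert!(dynamic_lookup_holds(&[(0, 2, 20), (0, 3, 30)], &tables));
    assert!(!dynamic_lookup_holds(&[(1, 2, 20)], &tables)); // wrong index tag
}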
@@ -1441,7 +1408,6 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd +
dim: usize,
is_flat_index: bool,
) -> Result<ValTensor<F>, CircuitError> {
let start_time = instant::Instant::now();
let index = values[0].clone();
if !is_flat_index {
assert_eq!(index.dims().len(), dims.len());
@@ -1515,9 +1481,6 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd +
region.apply_in_loop(&mut output, inner_loop_function)?;
let elapsed = start_time.elapsed();
trace!("linearize_element_index took: {:?}", elapsed);
Ok(output.into())
}
@@ -1949,16 +1912,11 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region.flush()?;
// time this entire function run
let global_start = instant::Instant::now();
let mut values = values.clone();
// this section has been optimized to death, don't mess with it
values[0].remove_const_zero_values();
let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);
// if empty return a const
if values[0].is_empty() {
return Ok(create_zero_tensor(1));
@@ -1970,12 +1928,8 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let (res, len) = region.assign_with_duplication(
&config.custom_gates.inputs[1],
&input,
&config.check_mode,
false,
)?;
let (res, len) =
region.assign_with_duplication_unconstrained(&config.custom_gates.inputs[1], &input)?;
assigned_len = len;
res.get_inner()?
};
@@ -1983,11 +1937,10 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// Now we can assign the dot product
let accumulated_sum = accumulated::sum(&input, block_width)?;
let (output, output_assigned_len) = region.assign_with_duplication(
let (output, output_assigned_len) = region.assign_with_duplication_constrained(
&config.custom_gates.output,
&accumulated_sum.into(),
&config.check_mode,
true,
)?;
// enable the selectors
@@ -2053,13 +2006,10 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
) -> Result<ValTensor<F>, CircuitError> {
region.flush()?;
// time this entire function run
let global_start = instant::Instant::now();
// this section has been optimized to death, don't mess with it
let removal_indices = values[0].get_const_zero_indices();
let elapsed = global_start.elapsed();
trace!("finding const zero indices took: {:?}", elapsed);
// if empty return a const
if !removal_indices.is_empty() {
return Ok(create_zero_tensor(1));
@@ -2070,12 +2020,8 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
let (res, len) = region.assign_with_duplication(
&config.custom_gates.inputs[1],
&input,
&config.check_mode,
false,
)?;
let (res, len) =
region.assign_with_duplication_unconstrained(&config.custom_gates.inputs[1], &input)?;
assigned_len = len;
res.get_inner()?
};
@@ -2083,11 +2029,10 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// Now we can assign the dot product
let accumulated_prod = accumulated::prod(&input, block_width)?;
let (output, output_assigned_len) = region.assign_with_duplication(
let (output, output_assigned_len) = region.assign_with_duplication_constrained(
&config.custom_gates.output,
&accumulated_prod.into(),
&config.check_mode,
true,
)?;
// enable the selectors
@@ -2440,7 +2385,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let orig_lhs = lhs.clone();
let orig_rhs = rhs.clone();
let start = instant::Instant::now();
let first_zero_indices = HashSet::from_iter(lhs.get_const_zero_indices());
let second_zero_indices = HashSet::from_iter(rhs.get_const_zero_indices());
@@ -2455,7 +2399,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
BaseOp::Sub => second_zero_indices.clone(),
_ => return Err(CircuitError::UnsupportedOp),
};
trace!("setting up indices took {:?}", start.elapsed());
if lhs.len() != rhs.len() {
return Err(CircuitError::DimMismatch(format!(
@@ -2480,7 +2423,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
// Now we can assign the dot product
// time the calc
let start = instant::Instant::now();
let op_result = match op {
BaseOp::Add => add(&inputs),
BaseOp::Sub => sub(&inputs),
@@ -2491,20 +2433,13 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
trace!("pairwise {} calc took {:?}", op.as_str(), start.elapsed());
let start = instant::Instant::now();
let assigned_len = op_result.len() - removal_indices.len();
let mut output = region.assign_with_omissions(
&config.custom_gates.output,
&op_result.into(),
&removal_indices,
)?;
trace!(
"pairwise {} input assign took {:?}",
op.as_str(),
start.elapsed()
);
// Enable the selectors
if !region.is_dummy() {
@@ -2671,9 +2606,7 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
rhs.expand(&broadcasted_shape)?;
let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;
let sign = sign(config, region, &[diff])?;
equals(config, region, &[sign, create_unit_tensor(1)])
}
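The comparison above reduces to a sign check; a plain-integer sketch of the identity being used (hypothetical, outside the circuit):

// a > b exactly when sign(a - b) == 1; the circuit computes sign(lhs - rhs)
// and equality-tests it against a unit tensor.
fn greater(a: i64, b: i64) -> bool {
    (a - b).signum() == 1
}

fn main() {
    assert!(greater(5, 3));
    assert!(!greater(3, 3));
    assert!(!greater(2, 3));
}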
@@ -5286,75 +5219,72 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
base: &usize,
n: &usize,
) -> Result<ValTensor<F>, CircuitError> {
let input = values[0].clone();
let mut input = values[0].clone();
let is_assigned = !input.all_prev_assigned();
let bases: ValTensor<F> = Tensor::from(
(0..*n)
.rev()
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))),
if !is_assigned {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
}
let mut bases: ValTensor<F> = Tensor::from(
// repeat it input.len() times
(0..input.len()).flat_map(|_| {
(0..*n)
.rev()
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep)))
}),
)
.into();
let mut bases_dims = input.dims().to_vec();
bases_dims.push(*n);
bases.reshape(&bases_dims)?;
let cartesian_coord = input
.dims()
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut decomposed_dims = input.dims().to_vec();
decomposed_dims.push(*n + 1);
let mut output: Tensor<Tensor<ValType<F>>> = Tensor::new(None, input.dims())?;
let claimed_output = if region.witness_gen() {
input.decompose(*base, *n)?
} else {
let decomposed_len = decomposed_dims.iter().product();
let claimed_output = Tensor::new(
Some(&vec![ValType::Value(Value::unknown()); decomposed_len]),
&decomposed_dims,
)?;
let inner_loop_function =
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
let coord = cartesian_coord[i].clone();
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
claimed_output.into()
};
region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());
if !is_assigned {
sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
}
let input_slice = input.dims().iter().map(|x| 0..*x).collect::<Vec<_>>();
let mut sign_slice = input_slice.clone();
sign_slice.push(0..1);
let mut rest_slice = input_slice.clone();
rest_slice.push(1..n + 1);
let mut claimed_output_slice = if region.witness_gen() {
sliced_input.decompose(*base, *n)?
} else {
Tensor::from(vec![ValType::Value(Value::unknown()); *n + 1].into_iter()).into()
};
let sign = claimed_output.get_slice(&sign_slice)?;
let rest = claimed_output.get_slice(&rest_slice)?;
claimed_output_slice =
region.assign(&config.custom_gates.inputs[1], &claimed_output_slice)?;
claimed_output_slice.flatten();
let sign = range_check(config, region, &[sign], &(-1, 1))?;
let rest = range_check(config, region, &[rest], &(0, (*base - 1) as i128))?;
region.increment(claimed_output_slice.len());
// equation needs to be constructed as ij,ij->i but for arbitrary n dims we need to construct this dynamically
// indices should map in order of the alphabet
// start with lhs
let lhs = ASCII_ALPHABET.chars().take(rest.dims().len()).join("");
let rhs = ASCII_ALPHABET.chars().take(rest.dims().len() - 1).join("");
let equation = format!("{},{}->{}", lhs, lhs, rhs);
// get the sign bit and make sure it is valid
let sign = claimed_output_slice.first()?;
let sign = range_check(config, region, &[sign], &(-1, 1))?;
// now add the rhs
// get the rest of the thing and make sure it is in the correct range
let rest = claimed_output_slice.get_slice(&[1..claimed_output_slice.len()])?;
let prod_decomp = einsum(config, region, &[rest.clone(), bases], &equation)?;
let rest = range_check(config, region, &[rest], &(0, (base - 1) as i128))?;
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
let prod_decomp = dot(config, region, &[rest, bases.clone()])?;
enforce_equality(config, region, &[input, signed_decomp])?;
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
enforce_equality(config, region, &[sliced_input, signed_decomp])?;
Ok(claimed_output_slice.get_inner_tensor()?.clone())
};
region.apply_in_loop(&mut output, inner_loop_function)?;
let mut combined_output = output.combine()?;
let mut output_dims = input.dims().to_vec();
output_dims.push(*n + 1);
combined_output.reshape(&output_dims)?;
Ok(combined_output.into())
Ok(claimed_output)
}
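The relation the simplified op checks can be stated as a small integer model (hypothetical: the circuit works over field elements and stores the sign and the n digits together in one n+1 wide claimed output):

// Hypothetical model: x == sign * dot(digits, bases), with sign in {-1, 0, 1},
// each digit in [0, base), and bases = [base^(n-1), ..., base^0].
fn decompose(x: i64, base: i64, n: usize) -> (i64, Vec<i64>) {
    let sign = x.signum();
    let mut mag = x.abs();
    let mut digits = vec![0i64; n];
    for i in (0..n).rev() {
        digits[i] = mag % base; // least significant digit last
        mag /= base;
    }
    (sign, digits)
}

fn main() {
    let (sign, digits) = decompose(-123, 10, 3);
    assert_eq!((sign, digits.clone()), (-1, vec![1, 2, 3]));
    // recompose the way the circuit does: dot(digits, bases), then apply the sign
    let bases = [100, 10, 1];
    let dot: i64 = digits.iter().zip(bases).map(|(d, b)| d * b).sum();
    assert_eq!(sign * dot, -123);
}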
pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

View File

@@ -671,22 +671,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication(
pub fn assign_with_duplication_unconstrained(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
check_mode: &crate::circuit::CheckMode,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication(
let (res, len) = var.assign_with_duplication_unconstrained(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
single_inner_col,
&mut self.assigned_constants,
)?;
Ok((res, len))
@@ -695,7 +690,37 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.row,
self.linear_coord,
values,
single_inner_col,
false,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_constrained(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
check_mode: &crate::circuit::CheckMode,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication_constrained(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
&mut self.assigned_constants,
)?;
Ok((res, len))
} else {
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.linear_coord,
values,
true,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))

View File

@@ -488,7 +488,6 @@ pub async fn deploy_da_verifier_via_solidity(
}
}
match call_to_account {
Some(call) => {
deploy_single_da_contract(

View File

@@ -280,7 +280,13 @@ impl GraphWitness {
})?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
let witness: GraphWitness =
serde_json::from_reader(reader).map_err(|e| Into::<GraphError>::into(e))?;
// check versions match
crate::check_version_string_matches(witness.version.as_deref().unwrap_or(""));
Ok(witness)
}
/// Save the model input to a file
@@ -572,10 +578,14 @@ impl GraphSettings {
// buf reader
let reader =
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
serde_json::from_reader(reader).map_err(|e| {
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
})?;
crate::check_version_string_matches(&settings.version);
Ok(settings)
}
/// Export the ezkl configuration as json
@@ -697,6 +707,9 @@ impl GraphCircuit {
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
// check the versions match
crate::check_version_string_matches(&result.core.settings.version);
Ok(result)
}
}

View File

@@ -1226,6 +1226,7 @@ impl Model {
values.iter().map(|v| v.dims()).collect_vec()
);
let start = instant::Instant::now();
match &node {
NodeType::Node(n) => {
let res = if node.is_constant() && node.num_uses() == 1 {
@@ -1363,6 +1364,7 @@ impl Model {
results.insert(*idx, full_results);
}
}
debug!("------------ layout of {} took {:?}", idx, start.elapsed());
}
// we do this so we can support multiple passes of the same model and have deterministic results (Non-assigned inputs etc... etc...)

View File

@@ -420,3 +420,30 @@ where
let b = s[pos + 2..].parse()?;
Ok((a, b))
}
/// Check if the version string matches the artifact version
/// If the version string does not match the artifact version, log a warning
pub fn check_version_string_matches(artifact_version: &str) {
if artifact_version == "0.0.0"
|| artifact_version == "source - no compatibility guaranteed"
|| artifact_version.is_empty()
{
log::warn!("Artifact version is 0.0.0, skipping version check");
return;
}
let version = crate::version();
if version == "source - no compatibility guaranteed" {
log::warn!("Compiled source version is not guaranteed to match artifact version");
return;
}
if version != artifact_version {
log::warn!(
"Version mismatch: CLI version is {} but artifact version is {}",
version,
artifact_version
);
}
}

View File

@@ -822,6 +822,7 @@ where
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{
debug!("loading proving key from {:?}", path);
let start = instant::Instant::now();
let f = File::open(path.clone()).map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
@@ -830,7 +831,8 @@ where
params,
)
.map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
info!("loaded proving key ✅");
let elapsed = start.elapsed();
info!("loaded proving key in {:?}", elapsed);
Ok(pk)
}

View File

@@ -833,7 +833,7 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.clone().into_iter().enumerate() {
if (i + offset + 1) % n == 0 {
@@ -862,20 +862,22 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
let mut indices_to_remove = std::collections::HashSet::new();
for i in 0..self.inner.len() {
if (i + initial_offset + 1) % n == 0 {
for j in 1..(1 + num_repeats) {
indices_to_remove.insert(i + j);
}
}
}
// Pre-calculate capacity to avoid reallocations
let estimated_size = self.inner.len() - (self.inner.len() / n) * num_repeats;
let mut inner = Vec::with_capacity(estimated_size);
let old_inner = self.inner.clone();
for (i, elem) in old_inner.into_iter().enumerate() {
if !indices_to_remove.contains(&i) {
inner.push(elem.clone());
// Use iterator directly instead of creating intermediate collections
let mut i = 0;
while i < self.inner.len() {
// Add the current element
inner.push(self.inner[i].clone());
// If this is an nth position (accounting for offset)
if (i + initial_offset + 1) % n == 0 {
// Skip the next num_repeats elements
i += num_repeats + 1;
} else {
i += 1;
}
}
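A self-contained sketch of the round trip these two functions implement (hypothetical simplification with a zero offset; the real versions also thread an initial_offset through the modulus):

// Every time an n-th output cell is emitted, repeat it num_repeats times;
// the inverse walk skips exactly those repeats.
fn duplicate_every_n(v: &[i64], n: usize, num_repeats: usize) -> Vec<i64> {
    let mut out = Vec::with_capacity(v.len() + num_repeats * (v.len() / n));
    for &x in v {
        out.push(x);
        if out.len() % n == 0 {
            for _ in 0..num_repeats {
                out.push(x); // duplicated boundary cell
            }
        }
    }
    out
}

fn remove_every_n(v: &[i64], n: usize, num_repeats: usize) -> Vec<i64> {
    let mut out = Vec::new();
    let mut i = 0;
    while i < v.len() {
        out.push(v[i]);
        // an n-th cell is followed by num_repeats copies: jump over them
        i += if (i + 1) % n == 0 { num_repeats + 1 } else { 1 };
    }
    out
}

fn main() {
    let v = vec![1, 2, 3, 4, 5];
    let dup = duplicate_every_n(&v, 3, 1);
    assert_eq!(dup, vec![1, 2, 3, 3, 4, 5, 5]);
    assert_eq!(remove_every_n(&dup, 3, 1), v);
}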

View File

@@ -494,16 +494,56 @@ impl VarTensor {
}
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
pub fn assign_with_duplication_unconstrained<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
ValTensor::Value { inner: v, dims , ..} => {
let duplication_freq = self.block_size();
let num_repeats = self.num_inner_cols();
let duplication_offset = offset;
// duplicates every nth element to adjust for column overflow
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
let mut res: ValTensor<F> = {
v.enum_map(|coord, k| {
let cell = self.assign_value(region, offset, k.clone(), coord, constants)?;
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?.into()};
let total_used_len = res.len();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.reshape(dims).unwrap();
res.set_scale(values.scale());
Ok((res, total_used_len))
}
}
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
pub fn assign_with_duplication_constrained<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
region: &mut Region<F>,
row: usize,
offset: usize,
values: &ValTensor<F>,
check_mode: &CheckMode,
single_inner_col: bool,
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
let mut prev_cell = None;
@@ -512,34 +552,16 @@ impl VarTensor {
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
ValTensor::Value { inner: v, dims , ..} => {
let duplication_freq = if single_inner_col {
self.col_size()
} else {
self.block_size()
};
let num_repeats = if single_inner_col {
1
} else {
self.num_inner_cols()
};
let duplication_offset = if single_inner_col {
row
} else {
offset
};
let duplication_freq = self.col_size();
let num_repeats = 1;
let duplication_offset = row;
// duplicates every nth element to adjust for column overflow
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
let mut res: ValTensor<F> = {
v.enum_map(|coord, k| {
let step = if !single_inner_col {
1
} else {
self.num_inner_cols()
};
let step = self.num_inner_cols();
let (x, y, z) = self.cartesian_coord(offset + coord * step);
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
@@ -549,11 +571,13 @@ impl VarTensor {
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
if single_inner_col {
if z == 0 {
let at_end_of_column = z == duplication_freq - 1;
let at_beginning_of_column = z == 0;
if at_end_of_column {
// if we are at the end of the column, we need to copy the cell to the next column
prev_cell = Some(cell.clone());
} else if coord > 0 && z == 0 && single_inner_col {
} else if coord > 0 && at_beginning_of_column {
if let Some(prev_cell) = prev_cell.as_ref() {
let cell = cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
@@ -563,10 +587,10 @@ impl VarTensor {
halo2_proofs::plonk::Error::Synthesis})?;
region.constrain_equal(prev_cell,cell)?;
} else {
error!("Error copy-constraining previous value: {:?}", (x,y));
error!("Previous cell was not set");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
}}
}
Ok(cell)
@@ -577,20 +601,6 @@ impl VarTensor {
res.reshape(dims).unwrap();
res.set_scale(values.scale());
if matches!(check_mode, CheckMode::SAFE) {
// during key generation this will be 0 so we use this as a flag to check
// TODO: this isn't very safe and would be better to get the phase directly
let res_evals = res.int_evals().unwrap();
let is_assigned = res_evals
.iter()
.all(|&x| x == 0);
if !is_assigned {
assert_eq!(
values.int_evals().unwrap(),
res_evals
)};
}
Ok((res, total_used_len))
}
}
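The renamed flags make the column walk explicit; a hypothetical simplification of the role each row plays (the real code derives z via cartesian_coord from the assigned offset):

// With column size col_size, a linear cell coord sits at row z = coord % col_size.
// The constrained variant remembers the cell at the last row of a column and
// copy-constrains the first row of the next column to it; that first cell is
// the duplicate that remove_every_n later strips out of the returned tensor.
fn boundary_roles(coord: usize, col_size: usize) -> (bool, bool) {
    let z = coord % col_size;
    (z == col_size - 1, z == 0) // (at_end_of_column, at_beginning_of_column)
}

fn main() {
    assert_eq!(boundary_roles(3, 4), (true, false)); // last row of column 0
    assert_eq!(boundary_roles(4, 4), (false, true)); // first row of column 1
}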

Binary file not shown.

Binary file not shown.