mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12735aefd4 |
@@ -568,10 +568,10 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
let sorted = if is_assigned {
|
||||
input
|
||||
.get_int_evals()?
|
||||
.iter()
|
||||
.sorted_by(|a, b| a.cmp(b))
|
||||
let mut int_evals = input.get_int_evals()?;
|
||||
int_evals.par_sort_unstable_by(|a, b| a.cmp(b));
|
||||
int_evals
|
||||
.par_iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
} else {
|
||||
@@ -753,20 +753,28 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
|
||||
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookups.tables[1], &table_1)?;
|
||||
let table_len = table_0.len();
|
||||
|
||||
trace!("assigning tables took: {:?}", start.elapsed());
|
||||
|
||||
// now create a vartensor of constants for the dynamic lookup index
|
||||
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
|
||||
let _table_index =
|
||||
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
|
||||
|
||||
trace!("assigning table index took: {:?}", start.elapsed());
|
||||
|
||||
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
|
||||
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
|
||||
let lookup_len = lookup_0.len();
|
||||
|
||||
trace!("assigning lookups took: {:?}", start.elapsed());
|
||||
|
||||
// now set the lookup index
|
||||
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
|
||||
|
||||
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
|
||||
|
||||
trace!("assigning lookup index took: {:?}", start.elapsed());
|
||||
|
||||
if !region.is_dummy() {
|
||||
(0..table_len)
|
||||
.map(|i| {
|
||||
@@ -3251,11 +3259,15 @@ pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// get the max then subtract it
|
||||
let max_val = max(config, region, values)?;
|
||||
// rebase the input to 0
|
||||
let sub = pairwise(config, region, &[values[0].clone(), max_val], BaseOp::Sub)?;
|
||||
// elementwise exponential
|
||||
let ex = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values,
|
||||
&[sub],
|
||||
&LookupOp::Exp { scale: input_scale },
|
||||
)?;
|
||||
|
||||
|
||||
@@ -163,7 +163,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
|
||||
///
|
||||
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
|
||||
self.assigned_constants.extend(constants.into_iter());
|
||||
self.assigned_constants.extend(constants);
|
||||
}
|
||||
|
||||
///
|
||||
@@ -389,7 +389,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
shuffle_index.update(&local_reg.shuffle_index);
|
||||
// update the constants
|
||||
let mut constants = constants.lock().unwrap();
|
||||
constants.extend(local_reg.assigned_constants.into_iter());
|
||||
constants.extend(local_reg.assigned_constants);
|
||||
|
||||
res
|
||||
})
|
||||
@@ -574,8 +574,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self.assigned_constants,
|
||||
)
|
||||
} else {
|
||||
let values_map = values.create_constants_map();
|
||||
self.assigned_constants.extend(values_map);
|
||||
if !values.is_instance() {
|
||||
let values_map = values.create_constants_map_iterator();
|
||||
self.assigned_constants.extend(values_map);
|
||||
}
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
@@ -599,8 +601,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self.assigned_constants,
|
||||
)
|
||||
} else {
|
||||
let values_map = values.create_constants_map();
|
||||
self.assigned_constants.extend(values_map);
|
||||
if !values.is_instance() {
|
||||
let values_map = values.create_constants_map_iterator();
|
||||
self.assigned_constants.extend(values_map);
|
||||
}
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
@@ -630,9 +634,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self.assigned_constants,
|
||||
)
|
||||
} else {
|
||||
let mut values_map = values.create_constants_map();
|
||||
|
||||
let inner_tensor = values.get_inner_tensor().unwrap();
|
||||
let mut values_map = values.create_constants_map();
|
||||
|
||||
for o in ommissions {
|
||||
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
|
||||
|
||||
@@ -911,7 +911,7 @@ pub(crate) fn calibrate(
|
||||
let model = Model::from_run_args(&settings.run_args, &model_path)?;
|
||||
|
||||
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
|
||||
debug!("num of calibration batches: {}", chunks.len());
|
||||
info!("num calibration batches: {}", chunks.len());
|
||||
|
||||
debug!("running onnx predictions...");
|
||||
let original_predictions = Model::run_onnx_predictions(
|
||||
|
||||
@@ -448,25 +448,39 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Returns the number of constants in the [ValTensor].
|
||||
pub fn num_constants(&self) -> usize {
|
||||
pub fn create_constants_map_iterator(
|
||||
&self,
|
||||
) -> core::iter::FilterMap<
|
||||
core::slice::Iter<'_, ValType<F>>,
|
||||
fn(&ValType<F>) -> Option<(F, ValType<F>)>,
|
||||
> {
|
||||
match self {
|
||||
ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(),
|
||||
ValTensor::Instance { .. } => 0,
|
||||
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
|
||||
if let ValType::Constant(v) = x {
|
||||
Some((*v, x.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}),
|
||||
ValTensor::Instance { .. } => {
|
||||
unreachable!("Instance tensors do not have constants")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of constants in the [ValTensor].
|
||||
pub fn create_constants_map(&self) -> ConstantsMap<F> {
|
||||
match self {
|
||||
ValTensor::Value { inner, .. } => {
|
||||
let map = inner.iter().fold(ConstantsMap::new(), |mut acc, x| {
|
||||
if let ValType::Constant(c) = x {
|
||||
acc.insert(*c, x.clone());
|
||||
ValTensor::Value { inner, .. } => inner
|
||||
.par_iter()
|
||||
.filter_map(|x| {
|
||||
if let ValType::Constant(v) = x {
|
||||
Some((*v, x.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
acc
|
||||
});
|
||||
map
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
ValTensor::Instance { .. } => ConstantsMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user