Compare commits

...

1 Commits

Author SHA1 Message Date
dante
12735aefd4 chore: reduce softmax recip DR (#756) 2024-03-27 01:14:29 +00:00
4 changed files with 54 additions and 25 deletions

View File

@@ -568,10 +568,10 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let is_assigned = !input.any_unknowns()?;
let sorted = if is_assigned {
input
.get_int_evals()?
.iter()
.sorted_by(|a, b| a.cmp(b))
let mut int_evals = input.get_int_evals()?;
int_evals.par_sort_unstable_by(|a, b| a.cmp(b));
int_evals
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
} else {
@@ -753,20 +753,28 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookups.tables[1], &table_1)?;
let table_len = table_0.len();
trace!("assigning tables took: {:?}", start.elapsed());
// now create a vartensor of constants for the dynamic lookup index
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
let _table_index =
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
trace!("assigning table index took: {:?}", start.elapsed());
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
let lookup_len = lookup_0.len();
trace!("assigning lookups took: {:?}", start.elapsed());
// now set the lookup index
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
trace!("assigning lookup index took: {:?}", start.elapsed());
if !region.is_dummy() {
(0..table_len)
.map(|i| {
@@ -3251,11 +3259,15 @@ pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
input_scale: utils::F32,
output_scale: utils::F32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// get the max then subtract it
let max_val = max(config, region, values)?;
// rebase the input to 0
let sub = pairwise(config, region, &[values[0].clone(), max_val], BaseOp::Sub)?;
// elementwise exponential
let ex = nonlinearity(
config,
region,
values,
&[sub],
&LookupOp::Exp { scale: input_scale },
)?;

View File

@@ -163,7 +163,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
///
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
self.assigned_constants.extend(constants.into_iter());
self.assigned_constants.extend(constants);
}
///
@@ -389,7 +389,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants.into_iter());
constants.extend(local_reg.assigned_constants);
res
})
@@ -574,8 +574,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -599,8 +601,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -630,9 +634,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let mut values_map = values.create_constants_map();
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {

View File

@@ -911,7 +911,7 @@ pub(crate) fn calibrate(
let model = Model::from_run_args(&settings.run_args, &model_path)?;
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
debug!("num of calibration batches: {}", chunks.len());
info!("num calibration batches: {}", chunks.len());
debug!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(

View File

@@ -448,25 +448,39 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Returns the number of constants in the [ValTensor].
pub fn num_constants(&self) -> usize {
pub fn create_constants_map_iterator(
&self,
) -> core::iter::FilterMap<
core::slice::Iter<'_, ValType<F>>,
fn(&ValType<F>) -> Option<(F, ValType<F>)>,
> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(),
ValTensor::Instance { .. } => 0,
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
}),
ValTensor::Instance { .. } => {
unreachable!("Instance tensors do not have constants")
}
}
}
/// Returns the number of constants in the [ValTensor].
pub fn create_constants_map(&self) -> ConstantsMap<F> {
match self {
ValTensor::Value { inner, .. } => {
let map = inner.iter().fold(ConstantsMap::new(), |mut acc, x| {
if let ValType::Constant(c) = x {
acc.insert(*c, x.clone());
ValTensor::Value { inner, .. } => inner
.par_iter()
.filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
acc
});
map
}
})
.collect(),
ValTensor::Instance { .. } => ConstantsMap::new(),
}
}