Compare commits

..

1 Commits

Author SHA1 Message Date
dante
4c8daf773c refactor: lookup-less layer norm (#706) 2024-02-07 21:19:17 +00:00

View File

@@ -1877,13 +1877,11 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
if normalized {
last_elem = nonlinearity(
last_elem = div(
config,
region,
&[last_elem],
&LookupOp::Div {
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
},
F::from((kernel_shape.0 * kernel_shape.1) as u64),
)?;
}
Ok(last_elem)
@@ -2619,22 +2617,6 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// mean function layout
pub fn mean<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let x = &values[0];
let sum_x = sum(config, region, &[x.clone()])?;
let nl = LookupOp::Div {
denom: utils::F32((scale * x.len()) as f32),
};
nonlinearity(config, region, &[sum_x], &nl)
}
/// Argmax
pub fn argmax<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,