mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
[const-eval] implement pow & clamp built-in functions properly
This commit is contained in:
@@ -128,6 +128,8 @@ pub enum ConstantEvaluatorError {
|
||||
InvalidMathArg,
|
||||
#[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
|
||||
InvalidMathArgCount(crate::MathFunction, usize, usize),
|
||||
#[error("value of `low` is greater than `high` for clamp built-in function")]
|
||||
InvalidClamp,
|
||||
#[error("Splat is defined only on scalar values")]
|
||||
SplatScalarOnly,
|
||||
#[error("Can only swizzle vector constants")]
|
||||
@@ -501,62 +503,183 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
));
|
||||
}
|
||||
|
||||
let const0 = &self.expressions[arg];
|
||||
let const1 = arg1.map(|arg| &self.expressions[arg]);
|
||||
let const2 = arg2.map(|arg| &self.expressions[arg]);
|
||||
let _const3 = arg3.map(|arg| &self.expressions[arg]);
|
||||
|
||||
match fun {
|
||||
crate::MathFunction::Pow => {
|
||||
let literal = match (const0, const1.unwrap()) {
|
||||
(&Expression::Literal(value0), &Expression::Literal(value1)) => {
|
||||
match (value0, value1) {
|
||||
(Literal::I32(a), Literal::I32(b)) => Literal::I32(a.pow(b as u32)),
|
||||
(Literal::U32(a), Literal::U32(b)) => Literal::U32(a.pow(b)),
|
||||
(Literal::F32(a), Literal::F32(b)) => Literal::F32(a.powf(b)),
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
}
|
||||
}
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
};
|
||||
|
||||
let expr = Expression::Literal(literal);
|
||||
Ok(self.register_evaluated_expr(expr, span))
|
||||
}
|
||||
crate::MathFunction::Clamp => {
|
||||
let literal = match (const0, const1.unwrap(), const2.unwrap()) {
|
||||
(
|
||||
&Expression::Literal(value0),
|
||||
&Expression::Literal(value1),
|
||||
&Expression::Literal(value2),
|
||||
) => match (value0, value1, value2) {
|
||||
(Literal::I32(a), Literal::I32(b), Literal::I32(c)) => {
|
||||
Literal::I32(a.clamp(b, c))
|
||||
}
|
||||
(Literal::U32(a), Literal::U32(b), Literal::U32(c)) => {
|
||||
Literal::U32(a.clamp(b, c))
|
||||
}
|
||||
(Literal::F32(a), Literal::F32(b), Literal::F32(c)) => {
|
||||
Literal::F32(glsl_float_clamp(a, b, c))
|
||||
}
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
},
|
||||
_ => {
|
||||
return Err(ConstantEvaluatorError::NotImplemented(
|
||||
"clamp built-in function with vector values".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let expr = Expression::Literal(literal);
|
||||
Ok(self.register_evaluated_expr(expr, span))
|
||||
}
|
||||
crate::MathFunction::Pow => self.math_pow(arg, arg1.unwrap(), span),
|
||||
crate::MathFunction::Clamp => self.math_clamp(arg, arg1.unwrap(), arg2.unwrap(), span),
|
||||
fun => Err(ConstantEvaluatorError::NotImplemented(format!(
|
||||
"{fun:?} built-in function"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn math_pow(
|
||||
&mut self,
|
||||
e1: Handle<Expression>,
|
||||
e2: Handle<Expression>,
|
||||
span: Span,
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
|
||||
let e1 = self.eval_zero_value_and_splat(e1, span)?;
|
||||
let e2 = self.eval_zero_value_and_splat(e2, span)?;
|
||||
|
||||
let expr = match (&self.expressions[e1], &self.expressions[e2]) {
|
||||
(&Expression::Literal(Literal::F32(a)), &Expression::Literal(Literal::F32(b))) => {
|
||||
Expression::Literal(Literal::F32(a.powf(b)))
|
||||
}
|
||||
(
|
||||
&Expression::Compose {
|
||||
components: ref src_components0,
|
||||
ty: ty0,
|
||||
},
|
||||
&Expression::Compose {
|
||||
components: ref src_components1,
|
||||
ty: ty1,
|
||||
},
|
||||
) if ty0 == ty1
|
||||
&& matches!(
|
||||
self.types[ty0].inner,
|
||||
crate::TypeInner::Vector {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
}
|
||||
) =>
|
||||
{
|
||||
let mut components: Vec<_> = crate::proc::flatten_compose(
|
||||
ty0,
|
||||
src_components0,
|
||||
self.expressions,
|
||||
self.types,
|
||||
)
|
||||
.chain(crate::proc::flatten_compose(
|
||||
ty1,
|
||||
src_components1,
|
||||
self.expressions,
|
||||
self.types,
|
||||
))
|
||||
.collect();
|
||||
|
||||
let mid = components.len() / 2;
|
||||
let (first, last) = components.split_at_mut(mid);
|
||||
for (a, b) in first.iter_mut().zip(&*last) {
|
||||
*a = self.math_pow(*a, *b, span)?;
|
||||
}
|
||||
components.truncate(mid);
|
||||
|
||||
Expression::Compose {
|
||||
ty: ty0,
|
||||
components,
|
||||
}
|
||||
}
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
};
|
||||
|
||||
Ok(self.register_evaluated_expr(expr, span))
|
||||
}
|
||||
|
||||
fn math_clamp(
|
||||
&mut self,
|
||||
e: Handle<Expression>,
|
||||
low: Handle<Expression>,
|
||||
high: Handle<Expression>,
|
||||
span: Span,
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
|
||||
let e = self.eval_zero_value_and_splat(e, span)?;
|
||||
let low = self.eval_zero_value_and_splat(low, span)?;
|
||||
let high = self.eval_zero_value_and_splat(high, span)?;
|
||||
|
||||
let expr = match (
|
||||
&self.expressions[e],
|
||||
&self.expressions[low],
|
||||
&self.expressions[high],
|
||||
) {
|
||||
(&Expression::Literal(e), &Expression::Literal(low), &Expression::Literal(high)) => {
|
||||
let literal = match (e, low, high) {
|
||||
(Literal::I32(e), Literal::I32(low), Literal::I32(high)) => {
|
||||
if low > high {
|
||||
return Err(ConstantEvaluatorError::InvalidClamp);
|
||||
} else {
|
||||
Literal::I32(e.clamp(low, high))
|
||||
}
|
||||
}
|
||||
(Literal::U32(e), Literal::U32(low), Literal::U32(high)) => {
|
||||
if low > high {
|
||||
return Err(ConstantEvaluatorError::InvalidClamp);
|
||||
} else {
|
||||
Literal::U32(e.clamp(low, high))
|
||||
}
|
||||
}
|
||||
(Literal::F32(e), Literal::F32(low), Literal::F32(high)) => {
|
||||
if low > high {
|
||||
return Err(ConstantEvaluatorError::InvalidClamp);
|
||||
} else {
|
||||
Literal::F32(e.clamp(low, high))
|
||||
}
|
||||
}
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
};
|
||||
Expression::Literal(literal)
|
||||
}
|
||||
(
|
||||
&Expression::Compose {
|
||||
components: ref src_components0,
|
||||
ty: ty0,
|
||||
},
|
||||
&Expression::Compose {
|
||||
components: ref src_components1,
|
||||
ty: ty1,
|
||||
},
|
||||
&Expression::Compose {
|
||||
components: ref src_components2,
|
||||
ty: ty2,
|
||||
},
|
||||
) if ty0 == ty1
|
||||
&& ty0 == ty2
|
||||
&& matches!(
|
||||
self.types[ty0].inner,
|
||||
crate::TypeInner::Vector {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
}
|
||||
) =>
|
||||
{
|
||||
let mut components: Vec<_> = crate::proc::flatten_compose(
|
||||
ty0,
|
||||
src_components0,
|
||||
self.expressions,
|
||||
self.types,
|
||||
)
|
||||
.chain(crate::proc::flatten_compose(
|
||||
ty1,
|
||||
src_components1,
|
||||
self.expressions,
|
||||
self.types,
|
||||
))
|
||||
.chain(crate::proc::flatten_compose(
|
||||
ty2,
|
||||
src_components2,
|
||||
self.expressions,
|
||||
self.types,
|
||||
))
|
||||
.collect();
|
||||
|
||||
let chunk_size = components.len() / 3;
|
||||
let (es, rem) = components.split_at_mut(chunk_size);
|
||||
let (lows, highs) = rem.split_at(chunk_size);
|
||||
for ((e, low), high) in es.iter_mut().zip(lows).zip(highs) {
|
||||
*e = self.math_clamp(*e, *low, *high, span)?;
|
||||
}
|
||||
components.truncate(chunk_size);
|
||||
|
||||
Expression::Compose {
|
||||
ty: ty0,
|
||||
components,
|
||||
}
|
||||
}
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
};
|
||||
|
||||
Ok(self.register_evaluated_expr(expr, span))
|
||||
}
|
||||
|
||||
fn array_length(
|
||||
&mut self,
|
||||
array: Handle<Expression>,
|
||||
@@ -1000,6 +1123,8 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
}
|
||||
|
||||
fn register_evaluated_expr(&mut self, expr: Expression, span: Span) -> Handle<Expression> {
|
||||
// TODO: use the validate_literal function from https://github.com/gfx-rs/naga/pull/2508 here
|
||||
|
||||
if let Some(FunctionLocalData {
|
||||
ref mut emitter,
|
||||
ref mut block,
|
||||
@@ -1026,57 +1151,6 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to implement the GLSL `max` function for floats.
|
||||
///
|
||||
/// While Rust does provide a `f64::max` method, it has a different behavior than the
|
||||
/// GLSL `max` for NaNs. In Rust, if any of the arguments is a NaN, then the other
|
||||
/// is returned.
|
||||
///
|
||||
/// This leads to different results in the following example
|
||||
/// ```
|
||||
/// use std::cmp::max;
|
||||
/// std::f64::NAN.max(1.0);
|
||||
/// ```
|
||||
///
|
||||
/// Rust will return `1.0` while GLSL should return NaN.
|
||||
fn glsl_float_max(x: f32, y: f32) -> f32 {
|
||||
if x < y {
|
||||
y
|
||||
} else {
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to implement the GLSL `min` function for floats.
|
||||
///
|
||||
/// While Rust does provide a `f64::min` method, it has a different behavior than the
|
||||
/// GLSL `min` for NaNs. In Rust, if any of the arguments is a NaN, then the other
|
||||
/// is returned.
|
||||
///
|
||||
/// This leads to different results in the following example
|
||||
/// ```
|
||||
/// use std::cmp::min;
|
||||
/// std::f64::NAN.min(1.0);
|
||||
/// ```
|
||||
///
|
||||
/// Rust will return `1.0` while GLSL should return NaN.
|
||||
fn glsl_float_min(x: f32, y: f32) -> f32 {
|
||||
if y < x {
|
||||
y
|
||||
} else {
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to implement the GLSL `clamp` function for floats.
|
||||
///
|
||||
/// While Rust does provide a `f64::clamp` method, it panics if either
|
||||
/// `min` or `max` are `NaN`s which is not the behavior specified by
|
||||
/// the glsl specification.
|
||||
fn glsl_float_clamp(value: f32, min: f32, max: f32) -> f32 {
|
||||
glsl_float_min(glsl_float_max(value, min), max)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::vec;
|
||||
@@ -1088,19 +1162,6 @@ mod tests {
|
||||
|
||||
use super::{Behavior, ConstantEvaluator};
|
||||
|
||||
#[test]
|
||||
fn nan_handling() {
|
||||
assert!(super::glsl_float_max(f32::NAN, 2.0).is_nan());
|
||||
assert!(!super::glsl_float_max(2.0, f32::NAN).is_nan());
|
||||
|
||||
assert!(super::glsl_float_min(f32::NAN, 2.0).is_nan());
|
||||
assert!(!super::glsl_float_min(2.0, f32::NAN).is_nan());
|
||||
|
||||
assert!(super::glsl_float_clamp(f32::NAN, 1.0, 2.0).is_nan());
|
||||
assert!(!super::glsl_float_clamp(1.0, f32::NAN, 2.0).is_nan());
|
||||
assert!(!super::glsl_float_clamp(1.0, 2.0, f32::NAN).is_nan());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unary_op() {
|
||||
let mut types = UniqueArena::new();
|
||||
|
||||
Reference in New Issue
Block a user