diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 42ba489790..c8911077b7 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -1398,14 +1398,19 @@ impl<'a> ConstantEvaluator<'a> { /// [`Compose`]: Expression::Compose fn eval_zero_value_and_splat( &mut self, - expr: Handle, + mut expr: Handle, span: Span, ) -> Result, ConstantEvaluatorError> { - match self.expressions[expr] { - Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span), - Expression::Splat { size, value } => self.splat(value, size, span), - _ => Ok(expr), + // The result of the splat() for a Splat of a scalar ZeroValue is a + // vector ZeroValue, so we must call eval_zero_value_impl() after + // splat() in order to ensure we have no ZeroValues remaining. + if let Expression::Splat { size, value } = self.expressions[expr] { + expr = self.splat(value, size, span)?; } + if let Expression::ZeroValue(ty) = self.expressions[expr] { + expr = self.eval_zero_value_impl(ty, span)?; + } + Ok(expr) } /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions. @@ -2976,4 +2981,84 @@ mod tests { panic!("unexpected evaluation result") } } + + #[test] + fn splat_of_zero_value() { + let mut types = UniqueArena::new(); + let constants = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); + + let f32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::F32), + }, + Default::default(), + ); + + let vec2_f32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Bi, + scalar: crate::Scalar::F32, + }, + }, + Default::default(), + ); + + let five = + global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default()); + let five_splat = global_expressions.append( + Expression::Splat { + size: VectorSize::Bi, + value: five, + }, + Default::default(), + ); + let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default()); + let zero_splat = global_expressions.append( + Expression::Splat { + size: VectorSize::Bi, + value: zero, + }, + Default::default(), + ); + + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); + let mut solver = ConstantEvaluator { + behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), + types: &mut types, + constants: &constants, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, + }; + + let solved_add = solver + .try_eval_and_append( + Expression::Binary { + op: crate::BinaryOperator::Add, + left: zero_splat, + right: five_splat, + }, + Default::default(), + ) + .unwrap(); + + let pass = match global_expressions[solved_add] { + Expression::Compose { ty, ref components } => { + ty == vec2_f32_ty + && components.iter().all(|&component| { + let component = &global_expressions[component]; + matches!(*component, Expression::Literal(Literal::F32(5.0))) + }) + } + _ => false, + }; + if !pass { + panic!("unexpected evaluation result") + } + } }