diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index eb45179b71..61adf0fdd5 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -78,7 +78,7 @@ impl ExpressionConstnessTracker { } } -#[derive(Clone, Debug, PartialEq, thiserror::Error)] +#[derive(Clone, Debug, thiserror::Error)] pub enum ConstantEvaluatorError { #[error("Constants cannot access function arguments")] FunctionArg, @@ -144,6 +144,8 @@ pub enum ConstantEvaluatorError { RemainderByZero, #[error("RHS of shift operation is greater than or equal to 32")] ShiftedMoreThan32Bits, + #[error(transparent)] + Literal(#[from] crate::valid::LiteralError), } impl<'a> ConstantEvaluator<'a> { @@ -270,18 +272,18 @@ impl<'a> ConstantEvaluator<'a> { Ok(self.constants[c].init) } Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { - Ok(self.register_evaluated_expr(expr.clone(), span)) + self.register_evaluated_expr(expr.clone(), span) } Expression::Compose { ty, ref components } => { let components = components .iter() .map(|component| self.check_and_get(*component)) .collect::, _>>()?; - Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span)) + self.register_evaluated_expr(Expression::Compose { ty, components }, span) } Expression::Splat { size, value } => { let value = self.check_and_get(value)?; - Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span)) + self.register_evaluated_expr(Expression::Splat { size, value }, span) } Expression::AccessIndex { base, index } => { let base = self.check_and_get(base)?; @@ -395,7 +397,7 @@ impl<'a> ConstantEvaluator<'a> { ty, components: vec![value; size as usize], }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } Expression::ZeroValue(ty) => { let inner = match self.types[ty].inner { @@ -404,7 +406,7 @@ impl<'a> ConstantEvaluator<'a> { }; let res_ty = self.types.insert(Type { name: None, inner }, span); let expr = Expression::ZeroValue(res_ty); - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } _ => Err(ConstantEvaluatorError::SplatScalarOnly), } @@ -436,11 +438,11 @@ impl<'a> ConstantEvaluator<'a> { Expression::ZeroValue(ty) => { let dst_ty = get_dst_ty(ty)?; let expr = Expression::ZeroValue(dst_ty); - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } Expression::Splat { value, .. } => { let expr = Expression::Splat { size, value }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } Expression::Compose { ty, ref components } => { let dst_ty = get_dst_ty(ty)?; @@ -468,7 +470,7 @@ impl<'a> ConstantEvaluator<'a> { ty: dst_ty, components: swizzled_components, }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } _ => Err(ConstantEvaluatorError::SwizzleVectorOnly), } @@ -565,7 +567,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidMathArg), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } fn math_clamp( @@ -670,7 +672,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidMathArg), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } fn array_length( @@ -684,7 +686,7 @@ impl<'a> ConstantEvaluator<'a> { TypeInner::Array { size, .. } => match size { crate::ArraySize::Constant(len) => { let expr = Expression::Literal(Literal::U32(len.get())); - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } crate::ArraySize::Dynamic => { Err(ConstantEvaluatorError::ArrayLengthDynamic) @@ -722,7 +724,7 @@ impl<'a> ConstantEvaluator<'a> { self.types.insert(Type { name: None, inner }, span) } }; - Ok(self.register_evaluated_expr(Expression::ZeroValue(ty), span)) + self.register_evaluated_expr(Expression::ZeroValue(ty), span) } } Expression::Splat { size, value } => { @@ -788,7 +790,7 @@ impl<'a> ConstantEvaluator<'a> { Literal::zero(kind, width) .ok_or(ConstantEvaluatorError::TypeNotConstructible)?, ); - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } TypeInner::Vector { size, kind, width } => { let scalar_ty = self.types.insert( @@ -803,7 +805,7 @@ impl<'a> ConstantEvaluator<'a> { ty, components: vec![el; size as usize], }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } TypeInner::Matrix { columns, @@ -826,7 +828,7 @@ impl<'a> ConstantEvaluator<'a> { ty, components: vec![el; columns as usize], }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } TypeInner::Array { base, @@ -838,7 +840,7 @@ impl<'a> ConstantEvaluator<'a> { ty, components: vec![el; size.get() as usize], }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } TypeInner::Struct { ref members, .. } => { let types: Vec<_> = members.iter().map(|m| m.ty).collect(); @@ -847,7 +849,7 @@ impl<'a> ConstantEvaluator<'a> { components.push(self.eval_zero_value_impl(ty, span)?); } let expr = Expression::Compose { ty, components }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } _ => Err(ConstantEvaluatorError::TypeNotConstructible), } @@ -933,7 +935,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidCastArg), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } fn unary_op( @@ -977,7 +979,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } fn binary_op( @@ -1113,7 +1115,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } /// Deep copy `expr` from `expressions` into `self.expressions`. @@ -1132,17 +1134,17 @@ impl<'a> ConstantEvaluator<'a> { match expressions[expr] { ref expr @ (Expression::Literal(_) | Expression::Constant(_) - | Expression::ZeroValue(_)) => Ok(self.register_evaluated_expr(expr.clone(), span)), + | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span), Expression::Compose { ty, ref components } => { let mut components = components.clone(); for component in &mut components { *component = self.copy_from(*component, expressions)?; } - Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span)) + self.register_evaluated_expr(Expression::Compose { ty, components }, span) } Expression::Splat { size, value } => { let value = self.copy_from(value, expressions)?; - Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span)) + self.register_evaluated_expr(Expression::Splat { size, value }, span) } _ => { log::debug!("copy_from: SubexpressionsAreNotConstant"); @@ -1151,8 +1153,17 @@ impl<'a> ConstantEvaluator<'a> { } } - fn register_evaluated_expr(&mut self, expr: Expression, span: Span) -> Handle { - // TODO: use the validate_literal function from https://github.com/gfx-rs/naga/pull/2508 here + fn register_evaluated_expr( + &mut self, + expr: Expression, + span: Span, + ) -> Result, ConstantEvaluatorError> { + // It suffices to only check literals, since we only register one + // expression at a time, `Compose` expressions can only refer to other + // expressions, and `ZeroValue` expressions are always okay. + if let Expression::Literal(literal) = expr { + crate::valid::validate_literal(literal)?; + } if let Some(FunctionLocalData { ref mut emitter, @@ -1168,14 +1179,14 @@ impl<'a> ConstantEvaluator<'a> { let h = self.expressions.append(expr, span); emitter.start(self.expressions); expression_constness.insert(h); - h + Ok(h) } else { let h = self.expressions.append(expr, span); expression_constness.insert(h); - h + Ok(h) } } else { - self.expressions.append(expr, span) + Ok(self.expressions.append(expr, span)) } } } diff --git a/src/valid/expression.rs b/src/valid/expression.rs index af4b774f12..95225a3926 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1566,7 +1566,7 @@ impl super::Validator { } } -fn validate_literal(literal: crate::Literal) -> Result<(), LiteralError> { +pub fn validate_literal(literal: crate::Literal) -> Result<(), LiteralError> { let is_nan = match literal { crate::Literal::F64(v) => v.is_nan(), crate::Literal::F32(v) => v.is_nan(), diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 6175aa0945..8c065bb159 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -24,6 +24,7 @@ use std::ops; use crate::span::{AddSpan as _, WithSpan}; pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; pub use compose::ComposeError; +pub use expression::{validate_literal, LiteralError}; pub use expression::{ConstExpressionError, ExpressionError}; pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError};