diff --git a/src/proc/validator.rs b/src/proc/validator.rs index d18240c872..4ca67bf3ac 100644 --- a/src/proc/validator.rs +++ b/src/proc/validator.rs @@ -189,6 +189,12 @@ pub enum ExpressionError { Handle, Handle, ), + #[error("Selecting is not possible")] + InvalidSelectTypes, + #[error("Relational argument {0:?} is not a boolean vector")] + InvalidBooleanVector(Handle), + #[error("Relational argument {0:?} is not a float")] + InvalidFloatArgument(Handle), } #[derive(Clone, Debug, Error)] @@ -358,6 +364,24 @@ fn storage_usage(access: crate::StorageAccess) -> GlobalUse { storage_usage } +impl crate::TypeInner { + fn is_sized(&self) -> bool { + match *self { + Self::Scalar { .. } + | Self::Vector { .. } + | Self::Matrix { .. } + | Self::Array { + size: crate::ArraySize::Constant(_), + .. + } + | Self::Pointer { .. } + | Self::ValuePointer { .. } + | Self::Struct { .. } => true, + Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } => false, + } + } +} + struct VaryingContext<'a> { ty: Handle, stage: crate::ShaderStage, @@ -892,7 +916,6 @@ impl Validator { root: Handle, expression: &crate::Expression, function: &crate::Function, - stage: Option, module: &crate::Module, ) -> Result<(), ExpressionError> { use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti}; @@ -1214,19 +1237,7 @@ impl Validator { }; kind_match && types_match && left_width == right_width } - Bo::Equal | Bo::NotEqual => match *left_inner { - Ti::Scalar { .. } - | Ti::Vector { .. } - | Ti::Matrix { .. } - | Ti::Array { - size: crate::ArraySize::Constant(_), - .. - } - | Ti::Pointer { .. } - | Ti::ValuePointer { .. } - | Ti::Struct { .. } => left_inner == right_inner, - Ti::Array { .. } | Ti::Image { .. } | Ti::Sampler { .. } => false, - }, + Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner, Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { match *left_inner { Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { @@ -1293,9 +1304,58 @@ impl Validator { condition, accept, reject, - } => {} - E::Derivative { axis, expr } => {} - E::Relational { argument, .. } => {} + } => { + let accept_inner = resolver.resolve(accept)?; + let reject_inner = resolver.resolve(reject)?; + let condition_good = match *resolver.resolve(condition)? { + Ti::Scalar { + kind: Sk::Bool, + width: _, + } => accept_inner.is_sized(), + Ti::Vector { + size, + kind: Sk::Bool, + width: _, + } => match *accept_inner { + Ti::Vector { + size: other_size, .. + } => size == other_size, + _ => false, + }, + _ => false, + }; + if !condition_good || accept_inner != reject_inner { + return Err(ExpressionError::InvalidSelectTypes); + } + } + E::Derivative { axis, expr } => { + //TODO: check stage + } + E::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + let argument_inner = resolver.resolve(argument)?; + match fun { + Rf::All | Rf::Any => match *argument_inner { + Ti::Vector { kind: Sk::Bool, .. } => {} + ref other => { + log::error!("All/Any of type {:?}", other); + return Err(ExpressionError::InvalidBooleanVector(argument)); + } + }, + Rf::IsNan | Rf::IsInf | Rf::IsFinite | Rf::IsNormal => match *argument_inner { + Ti::Scalar { + kind: Sk::Float, .. + } + | Ti::Vector { + kind: Sk::Float, .. + } => {} + ref other => { + log::error!("Float test of type {:?}", other); + return Err(ExpressionError::InvalidFloatArgument(argument)); + } + }, + } + } E::Math { fun, arg, @@ -1551,7 +1611,6 @@ impl Validator { fun: &crate::Function, _info: &FunctionInfo, module: &crate::Module, - stage: Option, ) -> Result<(), FunctionError> { let resolve_ctx = ResolveContext { constants: &module.constants, @@ -1586,7 +1645,7 @@ impl Validator { if expr.needs_pre_emit() { self.valid_expression_set.insert(handle.index()); } - if let Err(error) = self.validate_expression(handle, expr, fun, stage, module) { + if let Err(error) = self.validate_expression(handle, expr, fun, module) { return Err(FunctionError::Expression { handle, error }); } } @@ -1693,7 +1752,7 @@ impl Validator { } } - self.validate_function(&ep.function, info, module, Some(ep.stage))?; + self.validate_function(&ep.function, info, module)?; Ok(()) } @@ -1737,7 +1796,7 @@ impl Validator { } for (handle, fun) in module.functions.iter() { - self.validate_function(fun, &analysis[handle], module, None) + self.validate_function(fun, &analysis[handle], module) .map_err(|error| ValidationError::Function { handle, name: fun.name.clone().unwrap_or_default(),