From ca6876b7a0be2597136564b6bbafb9260d627083 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sun, 28 Mar 2021 23:33:29 -0400 Subject: [PATCH] Validate math functions --- src/front/wgsl/tests.rs | 13 ++ src/valid/expression.rs | 301 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 286 insertions(+), 28 deletions(-) diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index fbad55201c..b439a26b88 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -226,6 +226,19 @@ fn parse_texture_load() { .unwrap(); } +#[test] +fn parse_texture_store() { + parse_str( + " + var t: [[access(write)]] texture_storage_2d; + fn foo() { + textureStore(t, vec2(10, 20), vec4(0.0, 1.0, 2.0, 3.0)); + } + ", + ) + .unwrap(); +} + #[test] fn parse_texture_query() { parse_str( diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 5a5ea24563..b21f534886 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -97,6 +97,10 @@ pub enum ExpressionError { InvalidSampleLevelGradientType(crate::ImageDimension, Handle), #[error("Unable to cast")] InvalidCastArgument, + #[error("Invalid argument count for {0:?}")] + WrongArgumentCount(crate::MathFunction), + #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")] + InvalidArgumentType(crate::MathFunction, u32, Handle), } struct ExpressionTypeResolver<'a> { @@ -401,7 +405,7 @@ impl super::Validator { if let Some(expr) = array_index { match *resolver.resolve(expr)? { Ti::Scalar { - kind: crate::ScalarKind::Sint, + kind: Sk::Sint, width: _, } => {} _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)), @@ -419,12 +423,11 @@ impl super::Validator { }; match *resolver.resolve(coordinate)? { Ti::Scalar { - kind: crate::ScalarKind::Float, - .. + kind: Sk::Float, .. } if num_components == 1 => {} Ti::Vector { size, - kind: crate::ScalarKind::Float, + kind: Sk::Float, .. } if size as u32 == num_components => {} _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)), @@ -440,7 +443,7 @@ impl super::Validator { match module.types[ty].inner { Ti::Vector { size, - kind: crate::ScalarKind::Sint, + kind: Sk::Sint, .. } => size as u32 == num_components, _ => false, @@ -455,8 +458,7 @@ impl super::Validator { if let Some(expr) = depth_ref { match *resolver.resolve(expr)? { Ti::Scalar { - kind: crate::ScalarKind::Float, - .. + kind: Sk::Float, .. } => {} _ => return Err(ExpressionError::InvalidDepthReference(expr)), } @@ -474,8 +476,7 @@ impl super::Validator { crate::SampleLevel::Exact(expr) if can_level => { match *resolver.resolve(expr)? { Ti::Scalar { - kind: crate::ScalarKind::Float, - .. + kind: Sk::Float, .. } => {} _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)), } @@ -484,8 +485,7 @@ impl super::Validator { crate::SampleLevel::Bias(expr) => { match *resolver.resolve(expr)? { Ti::Scalar { - kind: crate::ScalarKind::Float, - .. + kind: Sk::Float, .. } => {} _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)), } @@ -494,12 +494,11 @@ impl super::Validator { crate::SampleLevel::Gradient { x, y } => { match *resolver.resolve(x)? { Ti::Scalar { - kind: crate::ScalarKind::Float, - .. + kind: Sk::Float, .. } if num_components == 1 => {} Ti::Vector { size, - kind: crate::ScalarKind::Float, + kind: Sk::Float, .. } if size as u32 == num_components => {} _ => { @@ -508,12 +507,11 @@ impl super::Validator { } match *resolver.resolve(y)? { Ti::Scalar { - kind: crate::ScalarKind::Float, - .. + kind: Sk::Float, .. } if num_components == 1 => {} Ti::Vector { size, - kind: crate::ScalarKind::Float, + kind: Sk::Float, .. } if size as u32 == num_components => {} _ => { @@ -564,7 +562,7 @@ impl super::Validator { if let Some(expr) = array_index { match *resolver.resolve(expr)? { Ti::Scalar { - kind: crate::ScalarKind::Sint, + kind: Sk::Sint, width: _, } => {} _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)), @@ -573,7 +571,7 @@ impl super::Validator { if let Some(expr) = index { match *resolver.resolve(expr)? { Ti::Scalar { - kind: crate::ScalarKind::Sint, + kind: Sk::Sint, width: _, } => {} _ => return Err(ExpressionError::InvalidImageOtherIndexType(expr)), @@ -807,12 +805,10 @@ impl super::Validator { E::Derivative { axis: _, expr } => { match *resolver.resolve(expr)? { Ti::Scalar { - kind: crate::ScalarKind::Float, - .. + kind: Sk::Float, .. } | Ti::Vector { - kind: crate::ScalarKind::Float, - .. + kind: Sk::Float, .. } => {} _ => return Err(ExpressionError::InvalidDerivative), } @@ -844,13 +840,264 @@ impl super::Validator { } ShaderStages::all() } - #[allow(unused)] E::Math { fun, arg, arg1, arg2, - } => ShaderStages::all(), + } => { + use crate::MathFunction as Mf; + + let arg_ty = resolver.resolve(arg)?; + let arg1_ty = arg1.map(|expr| resolver.resolve(expr)).transpose()?; + let arg2_ty = arg2.map(|expr| resolver.resolve(expr)).transpose()?; + match fun { + Mf::Abs => { + if arg1_ty.is_some() | arg2_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + let good = match *arg_ty { + Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool, + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + } + Mf::Min | Mf::Max => { + let arg1_ty = match (arg1_ty, arg2_ty) { + (Some(ty1), None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let good = match *arg_ty { + Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool, + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Clamp => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) { + (Some(ty1), Some(ty2)) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let good = match *arg_ty { + Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool, + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + if arg2_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )); + } + } + Mf::Cos + | Mf::Cosh + | Mf::Sin + | Mf::Sinh + | Mf::Tan + | Mf::Tanh + | Mf::Acos + | Mf::Asin + | Mf::Atan + | Mf::Ceil + | Mf::Floor + | Mf::Round + | Mf::Fract + | Mf::Trunc + | Mf::Exp + | Mf::Exp2 + | Mf::Log + | Mf::Log2 + | Mf::Length + | Mf::Sign + | Mf::Sqrt + | Mf::InverseSqrt => { + if arg1_ty.is_some() | arg2_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar { + kind: Sk::Float, .. + } + | Ti::Vector { + kind: Sk::Float, .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Atan2 | Mf::Pow | Mf::Distance => { + let arg1_ty = match (arg1_ty, arg2_ty) { + (Some(ty1), None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar { + kind: Sk::Float, .. + } + | Ti::Vector { + kind: Sk::Float, .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Modf | Mf::Frexp | Mf::Ldexp => { + let arg1_ty = match (arg1_ty, arg2_ty) { + (Some(ty1), None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let (size0, width0) = match *arg_ty { + Ti::Scalar { + kind: Sk::Float, + width, + } => (None, width), + Ti::Vector { + kind: Sk::Float, + size, + width, + } => (Some(size), width), + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + }; + let good = match *arg1_ty { + Ti::Pointer { base, class: _ } => module.types[base].inner == *arg_ty, + Ti::ValuePointer { + size, + kind: Sk::Float, + width, + class: _, + } => size == size0 && width == width0, + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Dot | Mf::Outer | Mf::Cross | Mf::Step | Mf::Reflect => { + let arg1_ty = match (arg1_ty, arg2_ty) { + (Some(ty1), None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Vector { + kind: Sk::Float, .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Normalize => { + if arg1_ty.is_some() | arg2_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Vector { + kind: Sk::Float, .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::FaceForward | Mf::Fma | Mf::Mix | Mf::SmoothStep => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) { + (Some(ty1), Some(ty2)) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar { + kind: Sk::Float, .. + } + | Ti::Vector { + kind: Sk::Float, .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + if arg2_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )); + } + } + Mf::Inverse | Mf::Determinant => { + if arg1_ty.is_some() | arg2_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + let good = match *arg_ty { + Ti::Matrix { columns, rows, .. } => columns == rows, + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + } + Mf::Transpose => { + if arg1_ty.is_some() | arg2_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Matrix { .. } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::CountOneBits | Mf::ReverseBits => { + if arg1_ty.is_some() | arg2_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar { kind: Sk::Sint, .. } + | Ti::Scalar { kind: Sk::Uint, .. } + | Ti::Vector { kind: Sk::Sint, .. } + | Ti::Vector { kind: Sk::Uint, .. } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + } + ShaderStages::all() + } E::As { expr, kind, @@ -860,9 +1107,7 @@ impl super::Validator { .resolve(expr)? .scalar_kind() .ok_or(ExpressionError::InvalidCastArgument)?; - if !convert && prev_kind == crate::ScalarKind::Bool - || kind == crate::ScalarKind::Bool - { + if !convert && prev_kind == Sk::Bool || kind == Sk::Bool { return Err(ExpressionError::InvalidCastArgument); } ShaderStages::all()