Validate math functions

This commit is contained in:
Dzmitry Malyshau
2021-03-28 23:33:29 -04:00
committed by Dzmitry Malyshau
parent 50059740f7
commit ca6876b7a0
2 changed files with 286 additions and 28 deletions

View File

@@ -226,6 +226,19 @@ fn parse_texture_load() {
.unwrap();
}
#[test]
fn parse_texture_store() {
parse_str(
"
var t: [[access(write)]] texture_storage_2d<rgba8unorm>;
fn foo() {
textureStore(t, vec2<i32>(10, 20), vec4<f32>(0.0, 1.0, 2.0, 3.0));
}
",
)
.unwrap();
}
#[test]
fn parse_texture_query() {
parse_str(

View File

@@ -97,6 +97,10 @@ pub enum ExpressionError {
InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
#[error("Unable to cast")]
InvalidCastArgument,
#[error("Invalid argument count for {0:?}")]
WrongArgumentCount(crate::MathFunction),
#[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
}
struct ExpressionTypeResolver<'a> {
@@ -401,7 +405,7 @@ impl super::Validator {
if let Some(expr) = array_index {
match *resolver.resolve(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Sint,
kind: Sk::Sint,
width: _,
} => {}
_ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
@@ -419,12 +423,11 @@ impl super::Validator {
};
match *resolver.resolve(coordinate)? {
Ti::Scalar {
kind: crate::ScalarKind::Float,
..
kind: Sk::Float, ..
} if num_components == 1 => {}
Ti::Vector {
size,
kind: crate::ScalarKind::Float,
kind: Sk::Float,
..
} if size as u32 == num_components => {}
_ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
@@ -440,7 +443,7 @@ impl super::Validator {
match module.types[ty].inner {
Ti::Vector {
size,
kind: crate::ScalarKind::Sint,
kind: Sk::Sint,
..
} => size as u32 == num_components,
_ => false,
@@ -455,8 +458,7 @@ impl super::Validator {
if let Some(expr) = depth_ref {
match *resolver.resolve(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Float,
..
kind: Sk::Float, ..
} => {}
_ => return Err(ExpressionError::InvalidDepthReference(expr)),
}
@@ -474,8 +476,7 @@ impl super::Validator {
crate::SampleLevel::Exact(expr) if can_level => {
match *resolver.resolve(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Float,
..
kind: Sk::Float, ..
} => {}
_ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)),
}
@@ -484,8 +485,7 @@ impl super::Validator {
crate::SampleLevel::Bias(expr) => {
match *resolver.resolve(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Float,
..
kind: Sk::Float, ..
} => {}
_ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
}
@@ -494,12 +494,11 @@ impl super::Validator {
crate::SampleLevel::Gradient { x, y } => {
match *resolver.resolve(x)? {
Ti::Scalar {
kind: crate::ScalarKind::Float,
..
kind: Sk::Float, ..
} if num_components == 1 => {}
Ti::Vector {
size,
kind: crate::ScalarKind::Float,
kind: Sk::Float,
..
} if size as u32 == num_components => {}
_ => {
@@ -508,12 +507,11 @@ impl super::Validator {
}
match *resolver.resolve(y)? {
Ti::Scalar {
kind: crate::ScalarKind::Float,
..
kind: Sk::Float, ..
} if num_components == 1 => {}
Ti::Vector {
size,
kind: crate::ScalarKind::Float,
kind: Sk::Float,
..
} if size as u32 == num_components => {}
_ => {
@@ -564,7 +562,7 @@ impl super::Validator {
if let Some(expr) = array_index {
match *resolver.resolve(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Sint,
kind: Sk::Sint,
width: _,
} => {}
_ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
@@ -573,7 +571,7 @@ impl super::Validator {
if let Some(expr) = index {
match *resolver.resolve(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Sint,
kind: Sk::Sint,
width: _,
} => {}
_ => return Err(ExpressionError::InvalidImageOtherIndexType(expr)),
@@ -807,12 +805,10 @@ impl super::Validator {
E::Derivative { axis: _, expr } => {
match *resolver.resolve(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Float,
..
kind: Sk::Float, ..
}
| Ti::Vector {
kind: crate::ScalarKind::Float,
..
kind: Sk::Float, ..
} => {}
_ => return Err(ExpressionError::InvalidDerivative),
}
@@ -844,13 +840,264 @@ impl super::Validator {
}
ShaderStages::all()
}
#[allow(unused)]
E::Math {
fun,
arg,
arg1,
arg2,
} => ShaderStages::all(),
} => {
use crate::MathFunction as Mf;
let arg_ty = resolver.resolve(arg)?;
let arg1_ty = arg1.map(|expr| resolver.resolve(expr)).transpose()?;
let arg2_ty = arg2.map(|expr| resolver.resolve(expr)).transpose()?;
match fun {
Mf::Abs => {
if arg1_ty.is_some() | arg2_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
let good = match *arg_ty {
Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
}
}
Mf::Min | Mf::Max => {
let arg1_ty = match (arg1_ty, arg2_ty) {
(Some(ty1), None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
let good = match *arg_ty {
Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Clamp => {
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) {
(Some(ty1), Some(ty2)) => (ty1, ty2),
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
let good = match *arg_ty {
Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
if arg2_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg2.unwrap(),
));
}
}
Mf::Cos
| Mf::Cosh
| Mf::Sin
| Mf::Sinh
| Mf::Tan
| Mf::Tanh
| Mf::Acos
| Mf::Asin
| Mf::Atan
| Mf::Ceil
| Mf::Floor
| Mf::Round
| Mf::Fract
| Mf::Trunc
| Mf::Exp
| Mf::Exp2
| Mf::Log
| Mf::Log2
| Mf::Length
| Mf::Sign
| Mf::Sqrt
| Mf::InverseSqrt => {
if arg1_ty.is_some() | arg2_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar {
kind: Sk::Float, ..
}
| Ti::Vector {
kind: Sk::Float, ..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::Atan2 | Mf::Pow | Mf::Distance => {
let arg1_ty = match (arg1_ty, arg2_ty) {
(Some(ty1), None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Scalar {
kind: Sk::Float, ..
}
| Ti::Vector {
kind: Sk::Float, ..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Modf | Mf::Frexp | Mf::Ldexp => {
let arg1_ty = match (arg1_ty, arg2_ty) {
(Some(ty1), None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
let (size0, width0) = match *arg_ty {
Ti::Scalar {
kind: Sk::Float,
width,
} => (None, width),
Ti::Vector {
kind: Sk::Float,
size,
width,
} => (Some(size), width),
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
};
let good = match *arg1_ty {
Ti::Pointer { base, class: _ } => module.types[base].inner == *arg_ty,
Ti::ValuePointer {
size,
kind: Sk::Float,
width,
class: _,
} => size == size0 && width == width0,
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Dot | Mf::Outer | Mf::Cross | Mf::Step | Mf::Reflect => {
let arg1_ty = match (arg1_ty, arg2_ty) {
(Some(ty1), None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Vector {
kind: Sk::Float, ..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Normalize => {
if arg1_ty.is_some() | arg2_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Vector {
kind: Sk::Float, ..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::FaceForward | Mf::Fma | Mf::Mix | Mf::SmoothStep => {
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) {
(Some(ty1), Some(ty2)) => (ty1, ty2),
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Scalar {
kind: Sk::Float, ..
}
| Ti::Vector {
kind: Sk::Float, ..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
if arg2_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg2.unwrap(),
));
}
}
Mf::Inverse | Mf::Determinant => {
if arg1_ty.is_some() | arg2_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
let good = match *arg_ty {
Ti::Matrix { columns, rows, .. } => columns == rows,
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
}
}
Mf::Transpose => {
if arg1_ty.is_some() | arg2_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Matrix { .. } => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::CountOneBits | Mf::ReverseBits => {
if arg1_ty.is_some() | arg2_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar { kind: Sk::Sint, .. }
| Ti::Scalar { kind: Sk::Uint, .. }
| Ti::Vector { kind: Sk::Sint, .. }
| Ti::Vector { kind: Sk::Uint, .. } => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
}
ShaderStages::all()
}
E::As {
expr,
kind,
@@ -860,9 +1107,7 @@ impl super::Validator {
.resolve(expr)?
.scalar_kind()
.ok_or(ExpressionError::InvalidCastArgument)?;
if !convert && prev_kind == crate::ScalarKind::Bool
|| kind == crate::ScalarKind::Bool
{
if !convert && prev_kind == Sk::Bool || kind == Sk::Bool {
return Err(ExpressionError::InvalidCastArgument);
}
ShaderStages::all()