mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Validate math functions
This commit is contained in:
committed by
Dzmitry Malyshau
parent
50059740f7
commit
ca6876b7a0
@@ -226,6 +226,19 @@ fn parse_texture_load() {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_texture_store() {
|
||||
parse_str(
|
||||
"
|
||||
var t: [[access(write)]] texture_storage_2d<rgba8unorm>;
|
||||
fn foo() {
|
||||
textureStore(t, vec2<i32>(10, 20), vec4<f32>(0.0, 1.0, 2.0, 3.0));
|
||||
}
|
||||
",
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_texture_query() {
|
||||
parse_str(
|
||||
|
||||
@@ -97,6 +97,10 @@ pub enum ExpressionError {
|
||||
InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
|
||||
#[error("Unable to cast")]
|
||||
InvalidCastArgument,
|
||||
#[error("Invalid argument count for {0:?}")]
|
||||
WrongArgumentCount(crate::MathFunction),
|
||||
#[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
|
||||
InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
|
||||
}
|
||||
|
||||
struct ExpressionTypeResolver<'a> {
|
||||
@@ -401,7 +405,7 @@ impl super::Validator {
|
||||
if let Some(expr) = array_index {
|
||||
match *resolver.resolve(expr)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Sint,
|
||||
kind: Sk::Sint,
|
||||
width: _,
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
|
||||
@@ -419,12 +423,11 @@ impl super::Validator {
|
||||
};
|
||||
match *resolver.resolve(coordinate)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
kind: Sk::Float, ..
|
||||
} if num_components == 1 => {}
|
||||
Ti::Vector {
|
||||
size,
|
||||
kind: crate::ScalarKind::Float,
|
||||
kind: Sk::Float,
|
||||
..
|
||||
} if size as u32 == num_components => {}
|
||||
_ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
|
||||
@@ -440,7 +443,7 @@ impl super::Validator {
|
||||
match module.types[ty].inner {
|
||||
Ti::Vector {
|
||||
size,
|
||||
kind: crate::ScalarKind::Sint,
|
||||
kind: Sk::Sint,
|
||||
..
|
||||
} => size as u32 == num_components,
|
||||
_ => false,
|
||||
@@ -455,8 +458,7 @@ impl super::Validator {
|
||||
if let Some(expr) = depth_ref {
|
||||
match *resolver.resolve(expr)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
kind: Sk::Float, ..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidDepthReference(expr)),
|
||||
}
|
||||
@@ -474,8 +476,7 @@ impl super::Validator {
|
||||
crate::SampleLevel::Exact(expr) if can_level => {
|
||||
match *resolver.resolve(expr)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
kind: Sk::Float, ..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)),
|
||||
}
|
||||
@@ -484,8 +485,7 @@ impl super::Validator {
|
||||
crate::SampleLevel::Bias(expr) => {
|
||||
match *resolver.resolve(expr)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
kind: Sk::Float, ..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
|
||||
}
|
||||
@@ -494,12 +494,11 @@ impl super::Validator {
|
||||
crate::SampleLevel::Gradient { x, y } => {
|
||||
match *resolver.resolve(x)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
kind: Sk::Float, ..
|
||||
} if num_components == 1 => {}
|
||||
Ti::Vector {
|
||||
size,
|
||||
kind: crate::ScalarKind::Float,
|
||||
kind: Sk::Float,
|
||||
..
|
||||
} if size as u32 == num_components => {}
|
||||
_ => {
|
||||
@@ -508,12 +507,11 @@ impl super::Validator {
|
||||
}
|
||||
match *resolver.resolve(y)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
kind: Sk::Float, ..
|
||||
} if num_components == 1 => {}
|
||||
Ti::Vector {
|
||||
size,
|
||||
kind: crate::ScalarKind::Float,
|
||||
kind: Sk::Float,
|
||||
..
|
||||
} if size as u32 == num_components => {}
|
||||
_ => {
|
||||
@@ -564,7 +562,7 @@ impl super::Validator {
|
||||
if let Some(expr) = array_index {
|
||||
match *resolver.resolve(expr)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Sint,
|
||||
kind: Sk::Sint,
|
||||
width: _,
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
|
||||
@@ -573,7 +571,7 @@ impl super::Validator {
|
||||
if let Some(expr) = index {
|
||||
match *resolver.resolve(expr)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Sint,
|
||||
kind: Sk::Sint,
|
||||
width: _,
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidImageOtherIndexType(expr)),
|
||||
@@ -807,12 +805,10 @@ impl super::Validator {
|
||||
E::Derivative { axis: _, expr } => {
|
||||
match *resolver.resolve(expr)? {
|
||||
Ti::Scalar {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
kind: Sk::Float, ..
|
||||
}
|
||||
| Ti::Vector {
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
kind: Sk::Float, ..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidDerivative),
|
||||
}
|
||||
@@ -844,13 +840,264 @@ impl super::Validator {
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
#[allow(unused)]
|
||||
E::Math {
|
||||
fun,
|
||||
arg,
|
||||
arg1,
|
||||
arg2,
|
||||
} => ShaderStages::all(),
|
||||
} => {
|
||||
use crate::MathFunction as Mf;
|
||||
|
||||
let arg_ty = resolver.resolve(arg)?;
|
||||
let arg1_ty = arg1.map(|expr| resolver.resolve(expr)).transpose()?;
|
||||
let arg2_ty = arg2.map(|expr| resolver.resolve(expr)).transpose()?;
|
||||
match fun {
|
||||
Mf::Abs => {
|
||||
if arg1_ty.is_some() | arg2_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
let good = match *arg_ty {
|
||||
Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
|
||||
}
|
||||
}
|
||||
Mf::Min | Mf::Max => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty) {
|
||||
(Some(ty1), None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
let good = match *arg_ty {
|
||||
Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Clamp => {
|
||||
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) {
|
||||
(Some(ty1), Some(ty2)) => (ty1, ty2),
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
let good = match *arg_ty {
|
||||
Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
if arg2_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg2.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Cos
|
||||
| Mf::Cosh
|
||||
| Mf::Sin
|
||||
| Mf::Sinh
|
||||
| Mf::Tan
|
||||
| Mf::Tanh
|
||||
| Mf::Acos
|
||||
| Mf::Asin
|
||||
| Mf::Atan
|
||||
| Mf::Ceil
|
||||
| Mf::Floor
|
||||
| Mf::Round
|
||||
| Mf::Fract
|
||||
| Mf::Trunc
|
||||
| Mf::Exp
|
||||
| Mf::Exp2
|
||||
| Mf::Log
|
||||
| Mf::Log2
|
||||
| Mf::Length
|
||||
| Mf::Sign
|
||||
| Mf::Sqrt
|
||||
| Mf::InverseSqrt => {
|
||||
if arg1_ty.is_some() | arg2_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Scalar {
|
||||
kind: Sk::Float, ..
|
||||
}
|
||||
| Ti::Vector {
|
||||
kind: Sk::Float, ..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::Atan2 | Mf::Pow | Mf::Distance => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty) {
|
||||
(Some(ty1), None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Scalar {
|
||||
kind: Sk::Float, ..
|
||||
}
|
||||
| Ti::Vector {
|
||||
kind: Sk::Float, ..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Modf | Mf::Frexp | Mf::Ldexp => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty) {
|
||||
(Some(ty1), None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
let (size0, width0) = match *arg_ty {
|
||||
Ti::Scalar {
|
||||
kind: Sk::Float,
|
||||
width,
|
||||
} => (None, width),
|
||||
Ti::Vector {
|
||||
kind: Sk::Float,
|
||||
size,
|
||||
width,
|
||||
} => (Some(size), width),
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
};
|
||||
let good = match *arg1_ty {
|
||||
Ti::Pointer { base, class: _ } => module.types[base].inner == *arg_ty,
|
||||
Ti::ValuePointer {
|
||||
size,
|
||||
kind: Sk::Float,
|
||||
width,
|
||||
class: _,
|
||||
} => size == size0 && width == width0,
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Dot | Mf::Outer | Mf::Cross | Mf::Step | Mf::Reflect => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty) {
|
||||
(Some(ty1), None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
kind: Sk::Float, ..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Normalize => {
|
||||
if arg1_ty.is_some() | arg2_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
kind: Sk::Float, ..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::FaceForward | Mf::Fma | Mf::Mix | Mf::SmoothStep => {
|
||||
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) {
|
||||
(Some(ty1), Some(ty2)) => (ty1, ty2),
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Scalar {
|
||||
kind: Sk::Float, ..
|
||||
}
|
||||
| Ti::Vector {
|
||||
kind: Sk::Float, ..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
if arg2_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg2.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Inverse | Mf::Determinant => {
|
||||
if arg1_ty.is_some() | arg2_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
let good = match *arg_ty {
|
||||
Ti::Matrix { columns, rows, .. } => columns == rows,
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
|
||||
}
|
||||
}
|
||||
Mf::Transpose => {
|
||||
if arg1_ty.is_some() | arg2_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Matrix { .. } => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::CountOneBits | Mf::ReverseBits => {
|
||||
if arg1_ty.is_some() | arg2_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Scalar { kind: Sk::Sint, .. }
|
||||
| Ti::Scalar { kind: Sk::Uint, .. }
|
||||
| Ti::Vector { kind: Sk::Sint, .. }
|
||||
| Ti::Vector { kind: Sk::Uint, .. } => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::As {
|
||||
expr,
|
||||
kind,
|
||||
@@ -860,9 +1107,7 @@ impl super::Validator {
|
||||
.resolve(expr)?
|
||||
.scalar_kind()
|
||||
.ok_or(ExpressionError::InvalidCastArgument)?;
|
||||
if !convert && prev_kind == crate::ScalarKind::Bool
|
||||
|| kind == crate::ScalarKind::Bool
|
||||
{
|
||||
if !convert && prev_kind == Sk::Bool || kind == Sk::Bool {
|
||||
return Err(ExpressionError::InvalidCastArgument);
|
||||
}
|
||||
ShaderStages::all()
|
||||
|
||||
Reference in New Issue
Block a user