From 7a246f6a14feefee871e24fa8ead2da1cde70abe Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 23 Mar 2021 22:48:02 -0400 Subject: [PATCH] Validate image queries and valid shader stages for derivatives --- src/lib.rs | 1 - src/valid/analyzer.rs | 6 ++- src/valid/expression.rs | 95 +++++++++++++++++++++++++-------- src/valid/function.rs | 14 +++-- src/valid/interface.rs | 14 ++++- src/valid/mod.rs | 11 ++++ tests/out/collatz.info.ron.snap | 6 +++ tests/out/shadow.info.ron.snap | 9 ++++ 8 files changed, 127 insertions(+), 29 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4f1992d3d9..8d2590f3fd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -120,7 +120,6 @@ pub enum ShaderStage { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[allow(missing_docs)] // The names are self evident pub enum StorageClass { /// Function locals. Function, diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 7175e94625..9e8e81c91b 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -6,7 +6,7 @@ Figures out the following properties: - expression reference counts !*/ -use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ValidationFlags}; +use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags}; use crate::{ arena::{Arena, Handle}, proc::{ResolveContext, TypeResolution}, @@ -164,6 +164,8 @@ impl ExpressionInfo { pub struct FunctionInfo { /// Validation flags. flags: ValidationFlags, + /// Set of shader stages where calling this function is valid. + pub available_stages: ShaderStages, /// Uniformity characteristics. pub uniformity: Uniformity, /// Function may kill the invocation. @@ -676,6 +678,7 @@ impl ModuleInfo { ) -> Result { let mut info = FunctionInfo { flags, + available_stages: ShaderStages::all(), uniformity: Uniformity::new(), may_kill: false, sampling_set: crate::FastHashSet::default(), @@ -779,6 +782,7 @@ fn uniform_control_flow() { let mut info = FunctionInfo { flags: ValidationFlags::all(), + available_stages: ShaderStages::all(), uniformity: Uniformity::new(), may_kill: false, sampling_set: crate::FastHashSet::default(), diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 457a867891..fbcd5b67e2 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1,4 +1,4 @@ -use super::FunctionInfo; +use super::{FunctionInfo, ShaderStages, TypeFlags}; use crate::{ arena::{Arena, Handle}, proc::ResolveError, @@ -59,6 +59,12 @@ pub enum ExpressionError { ExpectedGlobalVariable, #[error("Calling an undeclared function {0:?}")] CallToUndeclaredFunction(Handle), + #[error("Needs to be an image instead of {0:?}")] + ExpectedImageType(Handle), + #[error("Needs to be an image instead of {0:?}")] + ExpectedSamplerType(Handle), + #[error("Unable to operate on image class {0:?}")] + InvalidImageClass(crate::ImageClass), } struct ExpressionTypeResolver<'a> { @@ -88,7 +94,8 @@ impl super::Validator { function: &crate::Function, module: &crate::Module, info: &FunctionInfo, - ) -> Result<(), ExpressionError> { + other_infos: &[FunctionInfo], + ) -> Result { use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti}; let resolver = ExpressionTypeResolver { @@ -97,7 +104,7 @@ impl super::Validator { info, }; - match *expression { + let stages = match *expression { E::Access { base, index } => { match *resolver.resolve(base)? { Ti::Vector { .. } @@ -124,6 +131,7 @@ impl super::Validator { return Err(ExpressionError::InvalidIndexType(index)); } } + ShaderStages::all() } E::AccessIndex { base, index } => { let limit = match *resolver.resolve(base)? { @@ -147,12 +155,14 @@ impl super::Validator { if index >= limit { return Err(ExpressionError::IndexOutOfBounds(base, index)); } + ShaderStages::all() } E::Constant(handle) => { let _ = module .constants .try_get(handle) .ok_or(ExpressionError::ConstantDoesntExist(handle))?; + ShaderStages::all() } E::Compose { ref components, ty } => { match module @@ -269,31 +279,42 @@ impl super::Validator { return Err(ExpressionError::InvalidComposeType(ty)); } } + ShaderStages::all() } E::FunctionArgument(index) => { if index >= function.arguments.len() as u32 { return Err(ExpressionError::FunctionArgumentDoesntExist(index)); } + ShaderStages::all() } E::GlobalVariable(handle) => { let _ = module .global_variables .try_get(handle) .ok_or(ExpressionError::GlobalVarDoesntExist(handle))?; + ShaderStages::all() } E::LocalVariable(handle) => { let _ = function .local_variables .try_get(handle) .ok_or(ExpressionError::LocalVarDoesntExist(handle))?; + ShaderStages::all() } - E::Load { pointer } => match *resolver.resolve(pointer)? { - Ti::Pointer { .. } | Ti::ValuePointer { .. } => {} - ref other => { - log::error!("Loading {:?}", other); - return Err(ExpressionError::InvalidPointerType(pointer)); + E::Load { pointer } => { + match *resolver.resolve(pointer)? { + Ti::Pointer { base, .. } + if self.types[base.index()] + .flags + .contains(TypeFlags::SIZED | TypeFlags::DATA) => {} + Ti::ValuePointer { .. } => {} + ref other => { + log::error!("Loading {:?}", other); + return Err(ExpressionError::InvalidPointerType(pointer)); + } } - }, + ShaderStages::all() + } #[allow(unused)] E::ImageSample { image, @@ -303,16 +324,43 @@ impl super::Validator { offset, level, depth_ref, - } => {} + } => ShaderStages::all(), #[allow(unused)] E::ImageLoad { image, coordinate, array_index, index, - } => {} - #[allow(unused)] - E::ImageQuery { image, query } => {} + } => ShaderStages::all(), + E::ImageQuery { image, query } => { + match function.expressions[image] { + crate::Expression::GlobalVariable(var_handle) => { + let var = &module.global_variables[var_handle]; + match module.types[var.ty].inner { + Ti::Image { class, arrayed, .. } => { + let can_level = match class { + crate::ImageClass::Sampled { multi, .. } => !multi, + crate::ImageClass::Storage { .. } => false, + crate::ImageClass::Depth { .. } => true, + }; + let good = match query { + crate::ImageQuery::NumLayers => arrayed, + crate::ImageQuery::Size { level: Some(_) } + | crate::ImageQuery::NumLevels => can_level, + crate::ImageQuery::Size { level: None } + | crate::ImageQuery::NumSamples => !can_level, + }; + if !good { + return Err(ExpressionError::InvalidImageClass(class)); + } + } + _ => return Err(ExpressionError::ExpectedImageType(var.ty)), + } + } + _ => return Err(ExpressionError::ExpectedGlobalVariable), + } + ShaderStages::all() + } E::Unary { op, expr } => { use crate::UnaryOperator as Uo; let inner = resolver.resolve(expr)?; @@ -326,6 +374,7 @@ impl super::Validator { return Err(ExpressionError::InvalidUnaryOperandType(op, expr)); } } + ShaderStages::all() } E::Binary { op, left, right } => { use crate::BinaryOperator as Bo; @@ -472,6 +521,7 @@ impl super::Validator { if !good { return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right)); } + ShaderStages::all() } E::Select { condition, @@ -500,11 +550,10 @@ impl super::Validator { if !condition_good || accept_inner != reject_inner { return Err(ExpressionError::InvalidSelectTypes); } + ShaderStages::all() } #[allow(unused)] - E::Derivative { axis, expr } => { - //TODO: check stage - } + E::Derivative { axis, expr } => ShaderStages::FRAGMENT, E::Relational { fun, argument } => { use crate::RelationalFunction as Rf; let argument_inner = resolver.resolve(argument)?; @@ -529,6 +578,7 @@ impl super::Validator { } }, } + ShaderStages::all() } #[allow(unused)] E::Math { @@ -536,23 +586,22 @@ impl super::Validator { arg, arg1, arg2, - } => {} + } => ShaderStages::all(), #[allow(unused)] E::As { expr, kind, convert, - } => {} - #[allow(unused)] - E::Call(function) => {} + } => ShaderStages::all(), + E::Call(function) => other_infos[function.index()].available_stages, E::ArrayLength(expr) => match *resolver.resolve(expr)? { - Ti::Array { .. } => {} + Ti::Array { .. } => ShaderStages::all(), ref other => { log::error!("Array length of {:?}", other); return Err(ExpressionError::InvalidArrayType(expr)); } }, - } - Ok(()) + }; + Ok(stages) } } diff --git a/src/valid/function.rs b/src/valid/function.rs index 88f2a2795c..50653a78d9 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -453,7 +453,7 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, ) -> Result { - let info = mod_info.process_function(fun, module, self.flags)?; + let mut info = mod_info.process_function(fun, module, self.flags)?; for (var_handle, var) in fun.local_variables.iter() { self.validate_local_var(var, &module.types, &module.constants) @@ -482,8 +482,16 @@ impl super::Validator { self.valid_expression_set.insert(handle.index()); } if !self.flags.contains(ValidationFlags::EXPRESSIONS) { - if let Err(error) = self.validate_expression(handle, expr, fun, module, &info) { - return Err(FunctionError::Expression { handle, error }); + match self.validate_expression( + handle, + expr, + fun, + module, + &info, + &mod_info.functions, + ) { + Ok(stages) => info.available_stages &= stages, + Err(error) => return Err(FunctionError::Expression { handle, error }), } } } diff --git a/src/valid/interface.rs b/src/valid/interface.rs index b6fe1a5a4b..ee54a9f86a 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -1,6 +1,6 @@ use super::{ analyzer::{FunctionInfo, GlobalUse}, - Disalignment, FunctionError, ModuleInfo, TypeFlags, + Disalignment, FunctionError, ModuleInfo, ShaderStages, TypeFlags, }; use crate::arena::{Arena, Handle}; @@ -56,6 +56,8 @@ pub enum EntryPointError { UnexpectedWorkgroupSize, #[error("Workgroup size is out of range")] OutOfRangeWorkgroupSize, + #[error("Uses operations forbidden at this stage")] + ForbiddenStageOperations, #[error("Global variable {0:?} is used incorrectly as {1:?}")] InvalidGlobalUsage(Handle, GlobalUse), #[error("Bindings for {0:?} conflict with other resource")] @@ -370,8 +372,18 @@ impl super::Validator { return Err(EntryPointError::UnexpectedWorkgroupSize); } + let stage_bit = match ep.stage { + crate::ShaderStage::Vertex => ShaderStages::VERTEX, + crate::ShaderStage::Fragment => ShaderStages::FRAGMENT, + crate::ShaderStage::Compute => ShaderStages::COMPUTE, + }; + let info = self.validate_function(&ep.function, module, &mod_info)?; + if !info.available_stages.contains(stage_bit) { + return Err(EntryPointError::ForbiddenStageOperations); + } + self.location_mask.clear(); for (index, fa) in ep.function.arguments.iter().enumerate() { let ctx = VaryingContext { diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 615b2c1e50..aff2714756 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -32,6 +32,17 @@ bitflags::bitflags! { } } +bitflags::bitflags! { + /// Validation flags. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + pub struct ShaderStages: u8 { + const VERTEX = 0x1; + const FRAGMENT = 0x2; + const COMPUTE = 0x4; + } +} + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ModuleInfo { diff --git a/tests/out/collatz.info.ron.snap b/tests/out/collatz.info.ron.snap index 7a69490949..77dcd1958e 100644 --- a/tests/out/collatz.info.ron.snap +++ b/tests/out/collatz.info.ron.snap @@ -8,6 +8,9 @@ expression: output flags: ( bits: 7, ), + available_stages: ( + bits: 7, + ), uniformity: ( non_uniform_result: Some(5), requirements: ( @@ -350,6 +353,9 @@ expression: output flags: ( bits: 7, ), + available_stages: ( + bits: 7, + ), uniformity: ( non_uniform_result: Some(5), requirements: ( diff --git a/tests/out/shadow.info.ron.snap b/tests/out/shadow.info.ron.snap index 817820f840..622aa4b80c 100644 --- a/tests/out/shadow.info.ron.snap +++ b/tests/out/shadow.info.ron.snap @@ -8,6 +8,9 @@ expression: output flags: ( bits: 7, ), + available_stages: ( + bits: 7, + ), uniformity: ( non_uniform_result: Some(44), requirements: ( @@ -1006,6 +1009,9 @@ expression: output flags: ( bits: 7, ), + available_stages: ( + bits: 7, + ), uniformity: ( non_uniform_result: Some(44), requirements: ( @@ -2634,6 +2640,9 @@ expression: output flags: ( bits: 7, ), + available_stages: ( + bits: 7, + ), uniformity: ( non_uniform_result: Some(44), requirements: (