Validate image queries and valid shader stages for derivatives

This commit is contained in:
Dzmitry Malyshau
2021-03-23 22:48:02 -04:00
committed by Dzmitry Malyshau
parent ee43776c08
commit 7a246f6a14
8 changed files with 127 additions and 29 deletions

View File

@@ -120,7 +120,6 @@ pub enum ShaderStage {
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[allow(missing_docs)] // The names are self evident
pub enum StorageClass {
/// Function locals.
Function,

View File

@@ -6,7 +6,7 @@ Figures out the following properties:
- expression reference counts
!*/
use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ValidationFlags};
use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
use crate::{
arena::{Arena, Handle},
proc::{ResolveContext, TypeResolution},
@@ -164,6 +164,8 @@ impl ExpressionInfo {
pub struct FunctionInfo {
/// Validation flags.
flags: ValidationFlags,
/// Set of shader stages where calling this function is valid.
pub available_stages: ShaderStages,
/// Uniformity characteristics.
pub uniformity: Uniformity,
/// Function may kill the invocation.
@@ -676,6 +678,7 @@ impl ModuleInfo {
) -> Result<FunctionInfo, FunctionError> {
let mut info = FunctionInfo {
flags,
available_stages: ShaderStages::all(),
uniformity: Uniformity::new(),
may_kill: false,
sampling_set: crate::FastHashSet::default(),
@@ -779,6 +782,7 @@ fn uniform_control_flow() {
let mut info = FunctionInfo {
flags: ValidationFlags::all(),
available_stages: ShaderStages::all(),
uniformity: Uniformity::new(),
may_kill: false,
sampling_set: crate::FastHashSet::default(),

View File

@@ -1,4 +1,4 @@
use super::FunctionInfo;
use super::{FunctionInfo, ShaderStages, TypeFlags};
use crate::{
arena::{Arena, Handle},
proc::ResolveError,
@@ -59,6 +59,12 @@ pub enum ExpressionError {
ExpectedGlobalVariable,
#[error("Calling an undeclared function {0:?}")]
CallToUndeclaredFunction(Handle<crate::Function>),
#[error("Needs to be an image instead of {0:?}")]
ExpectedImageType(Handle<crate::Type>),
#[error("Needs to be an image instead of {0:?}")]
ExpectedSamplerType(Handle<crate::Type>),
#[error("Unable to operate on image class {0:?}")]
InvalidImageClass(crate::ImageClass),
}
struct ExpressionTypeResolver<'a> {
@@ -88,7 +94,8 @@ impl super::Validator {
function: &crate::Function,
module: &crate::Module,
info: &FunctionInfo,
) -> Result<(), ExpressionError> {
other_infos: &[FunctionInfo],
) -> Result<ShaderStages, ExpressionError> {
use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti};
let resolver = ExpressionTypeResolver {
@@ -97,7 +104,7 @@ impl super::Validator {
info,
};
match *expression {
let stages = match *expression {
E::Access { base, index } => {
match *resolver.resolve(base)? {
Ti::Vector { .. }
@@ -124,6 +131,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidIndexType(index));
}
}
ShaderStages::all()
}
E::AccessIndex { base, index } => {
let limit = match *resolver.resolve(base)? {
@@ -147,12 +155,14 @@ impl super::Validator {
if index >= limit {
return Err(ExpressionError::IndexOutOfBounds(base, index));
}
ShaderStages::all()
}
E::Constant(handle) => {
let _ = module
.constants
.try_get(handle)
.ok_or(ExpressionError::ConstantDoesntExist(handle))?;
ShaderStages::all()
}
E::Compose { ref components, ty } => {
match module
@@ -269,31 +279,42 @@ impl super::Validator {
return Err(ExpressionError::InvalidComposeType(ty));
}
}
ShaderStages::all()
}
E::FunctionArgument(index) => {
if index >= function.arguments.len() as u32 {
return Err(ExpressionError::FunctionArgumentDoesntExist(index));
}
ShaderStages::all()
}
E::GlobalVariable(handle) => {
let _ = module
.global_variables
.try_get(handle)
.ok_or(ExpressionError::GlobalVarDoesntExist(handle))?;
ShaderStages::all()
}
E::LocalVariable(handle) => {
let _ = function
.local_variables
.try_get(handle)
.ok_or(ExpressionError::LocalVarDoesntExist(handle))?;
ShaderStages::all()
}
E::Load { pointer } => match *resolver.resolve(pointer)? {
Ti::Pointer { .. } | Ti::ValuePointer { .. } => {}
ref other => {
log::error!("Loading {:?}", other);
return Err(ExpressionError::InvalidPointerType(pointer));
E::Load { pointer } => {
match *resolver.resolve(pointer)? {
Ti::Pointer { base, .. }
if self.types[base.index()]
.flags
.contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
Ti::ValuePointer { .. } => {}
ref other => {
log::error!("Loading {:?}", other);
return Err(ExpressionError::InvalidPointerType(pointer));
}
}
},
ShaderStages::all()
}
#[allow(unused)]
E::ImageSample {
image,
@@ -303,16 +324,43 @@ impl super::Validator {
offset,
level,
depth_ref,
} => {}
} => ShaderStages::all(),
#[allow(unused)]
E::ImageLoad {
image,
coordinate,
array_index,
index,
} => {}
#[allow(unused)]
E::ImageQuery { image, query } => {}
} => ShaderStages::all(),
E::ImageQuery { image, query } => {
match function.expressions[image] {
crate::Expression::GlobalVariable(var_handle) => {
let var = &module.global_variables[var_handle];
match module.types[var.ty].inner {
Ti::Image { class, arrayed, .. } => {
let can_level = match class {
crate::ImageClass::Sampled { multi, .. } => !multi,
crate::ImageClass::Storage { .. } => false,
crate::ImageClass::Depth { .. } => true,
};
let good = match query {
crate::ImageQuery::NumLayers => arrayed,
crate::ImageQuery::Size { level: Some(_) }
| crate::ImageQuery::NumLevels => can_level,
crate::ImageQuery::Size { level: None }
| crate::ImageQuery::NumSamples => !can_level,
};
if !good {
return Err(ExpressionError::InvalidImageClass(class));
}
}
_ => return Err(ExpressionError::ExpectedImageType(var.ty)),
}
}
_ => return Err(ExpressionError::ExpectedGlobalVariable),
}
ShaderStages::all()
}
E::Unary { op, expr } => {
use crate::UnaryOperator as Uo;
let inner = resolver.resolve(expr)?;
@@ -326,6 +374,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
}
}
ShaderStages::all()
}
E::Binary { op, left, right } => {
use crate::BinaryOperator as Bo;
@@ -472,6 +521,7 @@ impl super::Validator {
if !good {
return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right));
}
ShaderStages::all()
}
E::Select {
condition,
@@ -500,11 +550,10 @@ impl super::Validator {
if !condition_good || accept_inner != reject_inner {
return Err(ExpressionError::InvalidSelectTypes);
}
ShaderStages::all()
}
#[allow(unused)]
E::Derivative { axis, expr } => {
//TODO: check stage
}
E::Derivative { axis, expr } => ShaderStages::FRAGMENT,
E::Relational { fun, argument } => {
use crate::RelationalFunction as Rf;
let argument_inner = resolver.resolve(argument)?;
@@ -529,6 +578,7 @@ impl super::Validator {
}
},
}
ShaderStages::all()
}
#[allow(unused)]
E::Math {
@@ -536,23 +586,22 @@ impl super::Validator {
arg,
arg1,
arg2,
} => {}
} => ShaderStages::all(),
#[allow(unused)]
E::As {
expr,
kind,
convert,
} => {}
#[allow(unused)]
E::Call(function) => {}
} => ShaderStages::all(),
E::Call(function) => other_infos[function.index()].available_stages,
E::ArrayLength(expr) => match *resolver.resolve(expr)? {
Ti::Array { .. } => {}
Ti::Array { .. } => ShaderStages::all(),
ref other => {
log::error!("Array length of {:?}", other);
return Err(ExpressionError::InvalidArrayType(expr));
}
},
}
Ok(())
};
Ok(stages)
}
}

View File

@@ -453,7 +453,7 @@ impl super::Validator {
module: &crate::Module,
mod_info: &ModuleInfo,
) -> Result<FunctionInfo, FunctionError> {
let info = mod_info.process_function(fun, module, self.flags)?;
let mut info = mod_info.process_function(fun, module, self.flags)?;
for (var_handle, var) in fun.local_variables.iter() {
self.validate_local_var(var, &module.types, &module.constants)
@@ -482,8 +482,16 @@ impl super::Validator {
self.valid_expression_set.insert(handle.index());
}
if !self.flags.contains(ValidationFlags::EXPRESSIONS) {
if let Err(error) = self.validate_expression(handle, expr, fun, module, &info) {
return Err(FunctionError::Expression { handle, error });
match self.validate_expression(
handle,
expr,
fun,
module,
&info,
&mod_info.functions,
) {
Ok(stages) => info.available_stages &= stages,
Err(error) => return Err(FunctionError::Expression { handle, error }),
}
}
}

View File

@@ -1,6 +1,6 @@
use super::{
analyzer::{FunctionInfo, GlobalUse},
Disalignment, FunctionError, ModuleInfo, TypeFlags,
Disalignment, FunctionError, ModuleInfo, ShaderStages, TypeFlags,
};
use crate::arena::{Arena, Handle};
@@ -56,6 +56,8 @@ pub enum EntryPointError {
UnexpectedWorkgroupSize,
#[error("Workgroup size is out of range")]
OutOfRangeWorkgroupSize,
#[error("Uses operations forbidden at this stage")]
ForbiddenStageOperations,
#[error("Global variable {0:?} is used incorrectly as {1:?}")]
InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse),
#[error("Bindings for {0:?} conflict with other resource")]
@@ -370,8 +372,18 @@ impl super::Validator {
return Err(EntryPointError::UnexpectedWorkgroupSize);
}
let stage_bit = match ep.stage {
crate::ShaderStage::Vertex => ShaderStages::VERTEX,
crate::ShaderStage::Fragment => ShaderStages::FRAGMENT,
crate::ShaderStage::Compute => ShaderStages::COMPUTE,
};
let info = self.validate_function(&ep.function, module, &mod_info)?;
if !info.available_stages.contains(stage_bit) {
return Err(EntryPointError::ForbiddenStageOperations);
}
self.location_mask.clear();
for (index, fa) in ep.function.arguments.iter().enumerate() {
let ctx = VaryingContext {

View File

@@ -32,6 +32,17 @@ bitflags::bitflags! {
}
}
bitflags::bitflags! {
/// Validation flags.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ShaderStages: u8 {
const VERTEX = 0x1;
const FRAGMENT = 0x2;
const COMPUTE = 0x4;
}
}
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {

View File

@@ -8,6 +8,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(5),
requirements: (
@@ -350,6 +353,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(5),
requirements: (

View File

@@ -8,6 +8,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(44),
requirements: (
@@ -1006,6 +1009,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(44),
requirements: (
@@ -2634,6 +2640,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(44),
requirements: (