mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Validate image queries and valid shader stages for derivatives
This commit is contained in:
committed by
Dzmitry Malyshau
parent
ee43776c08
commit
7a246f6a14
@@ -120,7 +120,6 @@ pub enum ShaderStage {
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
#[allow(missing_docs)] // The names are self evident
|
||||
pub enum StorageClass {
|
||||
/// Function locals.
|
||||
Function,
|
||||
|
||||
@@ -6,7 +6,7 @@ Figures out the following properties:
|
||||
- expression reference counts
|
||||
!*/
|
||||
|
||||
use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ValidationFlags};
|
||||
use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
|
||||
use crate::{
|
||||
arena::{Arena, Handle},
|
||||
proc::{ResolveContext, TypeResolution},
|
||||
@@ -164,6 +164,8 @@ impl ExpressionInfo {
|
||||
pub struct FunctionInfo {
|
||||
/// Validation flags.
|
||||
flags: ValidationFlags,
|
||||
/// Set of shader stages where calling this function is valid.
|
||||
pub available_stages: ShaderStages,
|
||||
/// Uniformity characteristics.
|
||||
pub uniformity: Uniformity,
|
||||
/// Function may kill the invocation.
|
||||
@@ -676,6 +678,7 @@ impl ModuleInfo {
|
||||
) -> Result<FunctionInfo, FunctionError> {
|
||||
let mut info = FunctionInfo {
|
||||
flags,
|
||||
available_stages: ShaderStages::all(),
|
||||
uniformity: Uniformity::new(),
|
||||
may_kill: false,
|
||||
sampling_set: crate::FastHashSet::default(),
|
||||
@@ -779,6 +782,7 @@ fn uniform_control_flow() {
|
||||
|
||||
let mut info = FunctionInfo {
|
||||
flags: ValidationFlags::all(),
|
||||
available_stages: ShaderStages::all(),
|
||||
uniformity: Uniformity::new(),
|
||||
may_kill: false,
|
||||
sampling_set: crate::FastHashSet::default(),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::FunctionInfo;
|
||||
use super::{FunctionInfo, ShaderStages, TypeFlags};
|
||||
use crate::{
|
||||
arena::{Arena, Handle},
|
||||
proc::ResolveError,
|
||||
@@ -59,6 +59,12 @@ pub enum ExpressionError {
|
||||
ExpectedGlobalVariable,
|
||||
#[error("Calling an undeclared function {0:?}")]
|
||||
CallToUndeclaredFunction(Handle<crate::Function>),
|
||||
#[error("Needs to be an image instead of {0:?}")]
|
||||
ExpectedImageType(Handle<crate::Type>),
|
||||
#[error("Needs to be an image instead of {0:?}")]
|
||||
ExpectedSamplerType(Handle<crate::Type>),
|
||||
#[error("Unable to operate on image class {0:?}")]
|
||||
InvalidImageClass(crate::ImageClass),
|
||||
}
|
||||
|
||||
struct ExpressionTypeResolver<'a> {
|
||||
@@ -88,7 +94,8 @@ impl super::Validator {
|
||||
function: &crate::Function,
|
||||
module: &crate::Module,
|
||||
info: &FunctionInfo,
|
||||
) -> Result<(), ExpressionError> {
|
||||
other_infos: &[FunctionInfo],
|
||||
) -> Result<ShaderStages, ExpressionError> {
|
||||
use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti};
|
||||
|
||||
let resolver = ExpressionTypeResolver {
|
||||
@@ -97,7 +104,7 @@ impl super::Validator {
|
||||
info,
|
||||
};
|
||||
|
||||
match *expression {
|
||||
let stages = match *expression {
|
||||
E::Access { base, index } => {
|
||||
match *resolver.resolve(base)? {
|
||||
Ti::Vector { .. }
|
||||
@@ -124,6 +131,7 @@ impl super::Validator {
|
||||
return Err(ExpressionError::InvalidIndexType(index));
|
||||
}
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::AccessIndex { base, index } => {
|
||||
let limit = match *resolver.resolve(base)? {
|
||||
@@ -147,12 +155,14 @@ impl super::Validator {
|
||||
if index >= limit {
|
||||
return Err(ExpressionError::IndexOutOfBounds(base, index));
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::Constant(handle) => {
|
||||
let _ = module
|
||||
.constants
|
||||
.try_get(handle)
|
||||
.ok_or(ExpressionError::ConstantDoesntExist(handle))?;
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::Compose { ref components, ty } => {
|
||||
match module
|
||||
@@ -269,31 +279,42 @@ impl super::Validator {
|
||||
return Err(ExpressionError::InvalidComposeType(ty));
|
||||
}
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::FunctionArgument(index) => {
|
||||
if index >= function.arguments.len() as u32 {
|
||||
return Err(ExpressionError::FunctionArgumentDoesntExist(index));
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::GlobalVariable(handle) => {
|
||||
let _ = module
|
||||
.global_variables
|
||||
.try_get(handle)
|
||||
.ok_or(ExpressionError::GlobalVarDoesntExist(handle))?;
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::LocalVariable(handle) => {
|
||||
let _ = function
|
||||
.local_variables
|
||||
.try_get(handle)
|
||||
.ok_or(ExpressionError::LocalVarDoesntExist(handle))?;
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::Load { pointer } => match *resolver.resolve(pointer)? {
|
||||
Ti::Pointer { .. } | Ti::ValuePointer { .. } => {}
|
||||
ref other => {
|
||||
log::error!("Loading {:?}", other);
|
||||
return Err(ExpressionError::InvalidPointerType(pointer));
|
||||
E::Load { pointer } => {
|
||||
match *resolver.resolve(pointer)? {
|
||||
Ti::Pointer { base, .. }
|
||||
if self.types[base.index()]
|
||||
.flags
|
||||
.contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
|
||||
Ti::ValuePointer { .. } => {}
|
||||
ref other => {
|
||||
log::error!("Loading {:?}", other);
|
||||
return Err(ExpressionError::InvalidPointerType(pointer));
|
||||
}
|
||||
}
|
||||
},
|
||||
ShaderStages::all()
|
||||
}
|
||||
#[allow(unused)]
|
||||
E::ImageSample {
|
||||
image,
|
||||
@@ -303,16 +324,43 @@ impl super::Validator {
|
||||
offset,
|
||||
level,
|
||||
depth_ref,
|
||||
} => {}
|
||||
} => ShaderStages::all(),
|
||||
#[allow(unused)]
|
||||
E::ImageLoad {
|
||||
image,
|
||||
coordinate,
|
||||
array_index,
|
||||
index,
|
||||
} => {}
|
||||
#[allow(unused)]
|
||||
E::ImageQuery { image, query } => {}
|
||||
} => ShaderStages::all(),
|
||||
E::ImageQuery { image, query } => {
|
||||
match function.expressions[image] {
|
||||
crate::Expression::GlobalVariable(var_handle) => {
|
||||
let var = &module.global_variables[var_handle];
|
||||
match module.types[var.ty].inner {
|
||||
Ti::Image { class, arrayed, .. } => {
|
||||
let can_level = match class {
|
||||
crate::ImageClass::Sampled { multi, .. } => !multi,
|
||||
crate::ImageClass::Storage { .. } => false,
|
||||
crate::ImageClass::Depth { .. } => true,
|
||||
};
|
||||
let good = match query {
|
||||
crate::ImageQuery::NumLayers => arrayed,
|
||||
crate::ImageQuery::Size { level: Some(_) }
|
||||
| crate::ImageQuery::NumLevels => can_level,
|
||||
crate::ImageQuery::Size { level: None }
|
||||
| crate::ImageQuery::NumSamples => !can_level,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidImageClass(class));
|
||||
}
|
||||
}
|
||||
_ => return Err(ExpressionError::ExpectedImageType(var.ty)),
|
||||
}
|
||||
}
|
||||
_ => return Err(ExpressionError::ExpectedGlobalVariable),
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::Unary { op, expr } => {
|
||||
use crate::UnaryOperator as Uo;
|
||||
let inner = resolver.resolve(expr)?;
|
||||
@@ -326,6 +374,7 @@ impl super::Validator {
|
||||
return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
|
||||
}
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::Binary { op, left, right } => {
|
||||
use crate::BinaryOperator as Bo;
|
||||
@@ -472,6 +521,7 @@ impl super::Validator {
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right));
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::Select {
|
||||
condition,
|
||||
@@ -500,11 +550,10 @@ impl super::Validator {
|
||||
if !condition_good || accept_inner != reject_inner {
|
||||
return Err(ExpressionError::InvalidSelectTypes);
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
#[allow(unused)]
|
||||
E::Derivative { axis, expr } => {
|
||||
//TODO: check stage
|
||||
}
|
||||
E::Derivative { axis, expr } => ShaderStages::FRAGMENT,
|
||||
E::Relational { fun, argument } => {
|
||||
use crate::RelationalFunction as Rf;
|
||||
let argument_inner = resolver.resolve(argument)?;
|
||||
@@ -529,6 +578,7 @@ impl super::Validator {
|
||||
}
|
||||
},
|
||||
}
|
||||
ShaderStages::all()
|
||||
}
|
||||
#[allow(unused)]
|
||||
E::Math {
|
||||
@@ -536,23 +586,22 @@ impl super::Validator {
|
||||
arg,
|
||||
arg1,
|
||||
arg2,
|
||||
} => {}
|
||||
} => ShaderStages::all(),
|
||||
#[allow(unused)]
|
||||
E::As {
|
||||
expr,
|
||||
kind,
|
||||
convert,
|
||||
} => {}
|
||||
#[allow(unused)]
|
||||
E::Call(function) => {}
|
||||
} => ShaderStages::all(),
|
||||
E::Call(function) => other_infos[function.index()].available_stages,
|
||||
E::ArrayLength(expr) => match *resolver.resolve(expr)? {
|
||||
Ti::Array { .. } => {}
|
||||
Ti::Array { .. } => ShaderStages::all(),
|
||||
ref other => {
|
||||
log::error!("Array length of {:?}", other);
|
||||
return Err(ExpressionError::InvalidArrayType(expr));
|
||||
}
|
||||
},
|
||||
}
|
||||
Ok(())
|
||||
};
|
||||
Ok(stages)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,7 +453,7 @@ impl super::Validator {
|
||||
module: &crate::Module,
|
||||
mod_info: &ModuleInfo,
|
||||
) -> Result<FunctionInfo, FunctionError> {
|
||||
let info = mod_info.process_function(fun, module, self.flags)?;
|
||||
let mut info = mod_info.process_function(fun, module, self.flags)?;
|
||||
|
||||
for (var_handle, var) in fun.local_variables.iter() {
|
||||
self.validate_local_var(var, &module.types, &module.constants)
|
||||
@@ -482,8 +482,16 @@ impl super::Validator {
|
||||
self.valid_expression_set.insert(handle.index());
|
||||
}
|
||||
if !self.flags.contains(ValidationFlags::EXPRESSIONS) {
|
||||
if let Err(error) = self.validate_expression(handle, expr, fun, module, &info) {
|
||||
return Err(FunctionError::Expression { handle, error });
|
||||
match self.validate_expression(
|
||||
handle,
|
||||
expr,
|
||||
fun,
|
||||
module,
|
||||
&info,
|
||||
&mod_info.functions,
|
||||
) {
|
||||
Ok(stages) => info.available_stages &= stages,
|
||||
Err(error) => return Err(FunctionError::Expression { handle, error }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
analyzer::{FunctionInfo, GlobalUse},
|
||||
Disalignment, FunctionError, ModuleInfo, TypeFlags,
|
||||
Disalignment, FunctionError, ModuleInfo, ShaderStages, TypeFlags,
|
||||
};
|
||||
use crate::arena::{Arena, Handle};
|
||||
|
||||
@@ -56,6 +56,8 @@ pub enum EntryPointError {
|
||||
UnexpectedWorkgroupSize,
|
||||
#[error("Workgroup size is out of range")]
|
||||
OutOfRangeWorkgroupSize,
|
||||
#[error("Uses operations forbidden at this stage")]
|
||||
ForbiddenStageOperations,
|
||||
#[error("Global variable {0:?} is used incorrectly as {1:?}")]
|
||||
InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse),
|
||||
#[error("Bindings for {0:?} conflict with other resource")]
|
||||
@@ -370,8 +372,18 @@ impl super::Validator {
|
||||
return Err(EntryPointError::UnexpectedWorkgroupSize);
|
||||
}
|
||||
|
||||
let stage_bit = match ep.stage {
|
||||
crate::ShaderStage::Vertex => ShaderStages::VERTEX,
|
||||
crate::ShaderStage::Fragment => ShaderStages::FRAGMENT,
|
||||
crate::ShaderStage::Compute => ShaderStages::COMPUTE,
|
||||
};
|
||||
|
||||
let info = self.validate_function(&ep.function, module, &mod_info)?;
|
||||
|
||||
if !info.available_stages.contains(stage_bit) {
|
||||
return Err(EntryPointError::ForbiddenStageOperations);
|
||||
}
|
||||
|
||||
self.location_mask.clear();
|
||||
for (index, fa) in ep.function.arguments.iter().enumerate() {
|
||||
let ctx = VaryingContext {
|
||||
|
||||
@@ -32,6 +32,17 @@ bitflags::bitflags! {
|
||||
}
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
/// Validation flags.
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
||||
pub struct ShaderStages: u8 {
|
||||
const VERTEX = 0x1;
|
||||
const FRAGMENT = 0x2;
|
||||
const COMPUTE = 0x4;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
||||
pub struct ModuleInfo {
|
||||
|
||||
@@ -8,6 +8,9 @@ expression: output
|
||||
flags: (
|
||||
bits: 7,
|
||||
),
|
||||
available_stages: (
|
||||
bits: 7,
|
||||
),
|
||||
uniformity: (
|
||||
non_uniform_result: Some(5),
|
||||
requirements: (
|
||||
@@ -350,6 +353,9 @@ expression: output
|
||||
flags: (
|
||||
bits: 7,
|
||||
),
|
||||
available_stages: (
|
||||
bits: 7,
|
||||
),
|
||||
uniformity: (
|
||||
non_uniform_result: Some(5),
|
||||
requirements: (
|
||||
|
||||
@@ -8,6 +8,9 @@ expression: output
|
||||
flags: (
|
||||
bits: 7,
|
||||
),
|
||||
available_stages: (
|
||||
bits: 7,
|
||||
),
|
||||
uniformity: (
|
||||
non_uniform_result: Some(44),
|
||||
requirements: (
|
||||
@@ -1006,6 +1009,9 @@ expression: output
|
||||
flags: (
|
||||
bits: 7,
|
||||
),
|
||||
available_stages: (
|
||||
bits: 7,
|
||||
),
|
||||
uniformity: (
|
||||
non_uniform_result: Some(44),
|
||||
requirements: (
|
||||
@@ -2634,6 +2640,9 @@ expression: output
|
||||
flags: (
|
||||
bits: 7,
|
||||
),
|
||||
available_stages: (
|
||||
bits: 7,
|
||||
),
|
||||
uniformity: (
|
||||
non_uniform_result: Some(44),
|
||||
requirements: (
|
||||
|
||||
Reference in New Issue
Block a user