From ce49afa3919ee1780fb6452d5ef82c05ef1e6f7e Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 27 Oct 2020 09:10:55 -0400 Subject: [PATCH] Clean up the storage classes (#245) --- src/back/glsl.rs | 20 +++--- src/back/msl.rs | 14 ++-- src/back/spv/layout.rs | 4 ++ src/back/spv/writer.rs | 11 ++- src/front/glsl/parser.rs | 4 +- src/front/spv/convert.rs | 15 ---- src/front/spv/mod.rs | 34 +++++++-- src/front/wgsl/conv.rs | 3 +- src/front/wgsl/mod.rs | 15 +++- src/lib.rs | 16 ++++- src/proc/interface.rs | 2 +- src/proc/validator.rs | 150 ++++++++++++++++++++++++--------------- test-data/boids.wgsl | 4 +- test-data/quad.wgsl | 4 +- 14 files changed, 184 insertions(+), 112 deletions(-) diff --git a/src/back/glsl.rs b/src/back/glsl.rs index fee127199b..8463b1cff2 100644 --- a/src/back/glsl.rs +++ b/src/back/glsl.rs @@ -112,6 +112,7 @@ bitflags::bitflags! { const IMAGE_LOAD_STORE = 1 << 8; const CONSERVATIVE_DEPTH = 1 << 9; const TEXTURE_1D = 1 << 10; + const PUSH_CONSTANT = 1 << 11; } } @@ -364,7 +365,7 @@ pub fn write<'a>( } let block = match global.class { - StorageClass::StorageBuffer | StorageClass::Uniform => true, + StorageClass::Storage | StorageClass::Uniform => true, _ => false, }; @@ -557,14 +558,15 @@ pub fn write<'a>( let name = if let Some(ref binding) = global.binding { let prefix = match global.class { - StorageClass::Constant => "const", StorageClass::Function => "fn", StorageClass::Input => "in", StorageClass::Output => "out", StorageClass::Private => "priv", - StorageClass::StorageBuffer => "buffer", + StorageClass::Storage => "buffer", StorageClass::Uniform => "uniform", + StorageClass::Handle => "handle", StorageClass::WorkGroup => "wg", + StorageClass::PushConstant => "pc", }; match binding { @@ -606,7 +608,7 @@ pub fn write<'a>( } let block = match global.class { - StorageClass::StorageBuffer | StorageClass::Uniform => { + StorageClass::Storage | StorageClass::Uniform => { Some(format!("global_block_{}", handle.index())) } _ => None, @@ -1492,22 +1494,24 @@ fn write_storage_class( manager: &mut FeaturesManager, ) -> Result<&'static str, Error> { Ok(match class { - StorageClass::Constant => "", StorageClass::Function => "", StorageClass::Input => "in ", StorageClass::Output => "out ", StorageClass::Private => "", - StorageClass::StorageBuffer => { + StorageClass::Storage => { manager.request(Features::BUFFER_STORAGE); - "buffer " } StorageClass::Uniform => "uniform ", + StorageClass::Handle => "uniform ", StorageClass::WorkGroup => { manager.request(Features::COMPUTE_SHADER); - "shared " } + StorageClass::PushConstant => { + manager.request(Features::PUSH_CONSTANT); + "" + } }) } diff --git a/src/back/msl.rs b/src/back/msl.rs index d7ff76069e..007fa85aea 100644 --- a/src/back/msl.rs +++ b/src/back/msl.rs @@ -254,9 +254,7 @@ impl<'a> TypedGlobalVariable<'a> { let (space_qualifier, reference) = match ty.inner { crate::TypeInner::Struct { .. } => match var.class { - crate::StorageClass::Constant - | crate::StorageClass::Uniform - | crate::StorageClass::StorageBuffer => { + crate::StorageClass::Uniform | crate::StorageClass::Storage => { let space = if self.usage.contains(crate::GlobalUse::STORE) { "device " } else { @@ -837,9 +835,13 @@ impl Writer { let base_name = module.types[base].name.or_index(base); let class_name = match class { Sc::Input | Sc::Output => continue, - Sc::Constant | Sc::Uniform => "constant", - Sc::StorageBuffer => "device", - Sc::Private | Sc::Function | Sc::WorkGroup => "", + Sc::Uniform => "constant", + Sc::Storage => "device", + Sc::Handle + | Sc::Private + | Sc::Function + | Sc::WorkGroup + | Sc::PushConstant => "", }; write!(self.out, "typedef {} {} *{}", class_name, base_name, name)?; } diff --git a/src/back/spv/layout.rs b/src/back/spv/layout.rs index e4c3d1cc0d..006e785317 100644 --- a/src/back/spv/layout.rs +++ b/src/back/spv/layout.rs @@ -24,6 +24,10 @@ impl PhysicalLayout { sink.extend(iter::once(self.bound)); sink.extend(iter::once(self.instruction_schema)); } + + pub(super) fn supports_storage_buffers(&self) -> bool { + self.version >= 0x10300 + } } impl LogicalLayout { diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index fa072a6698..b0e0a400c6 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -417,14 +417,19 @@ impl Writer { fn parse_to_spirv_storage_class(&self, class: crate::StorageClass) -> spirv::StorageClass { match class { - crate::StorageClass::Constant => spirv::StorageClass::UniformConstant, + crate::StorageClass::Handle => spirv::StorageClass::UniformConstant, crate::StorageClass::Function => spirv::StorageClass::Function, crate::StorageClass::Input => spirv::StorageClass::Input, crate::StorageClass::Output => spirv::StorageClass::Output, crate::StorageClass::Private => spirv::StorageClass::Private, - crate::StorageClass::StorageBuffer => spirv::StorageClass::StorageBuffer, - crate::StorageClass::Uniform => spirv::StorageClass::Uniform, + crate::StorageClass::Storage if self.physical_layout.supports_storage_buffers() => { + spirv::StorageClass::StorageBuffer + } + crate::StorageClass::Storage | crate::StorageClass::Uniform => { + spirv::StorageClass::Uniform + } crate::StorageClass::WorkGroup => spirv::StorageClass::Workgroup, + crate::StorageClass::PushConstant => spirv::StorageClass::PushConstant, } } diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index 4815adacd4..0ec67099ea 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -596,9 +596,7 @@ pomelo! { // single_type_qualifier ::= invariant_qualifier; // single_type_qualifier ::= precise_qualifier; - storage_qualifier ::= Const { - StorageClass::Constant - } + // storage_qualifier ::= Const // storage_qualifier ::= InOut; storage_qualifier ::= In { StorageClass::Input diff --git a/src/front/spv/convert.rs b/src/front/spv/convert.rs index c7b5cfa483..6036d08946 100644 --- a/src/front/spv/convert.rs +++ b/src/front/spv/convert.rs @@ -43,21 +43,6 @@ pub fn map_vector_size(word: spirv::Word) -> Result { } } -pub fn map_storage_class(word: spirv::Word) -> Result { - use spirv::StorageClass as Sc; - match Sc::from_u32(word) { - Some(Sc::UniformConstant) => Ok(crate::StorageClass::Constant), - Some(Sc::Function) => Ok(crate::StorageClass::Function), - Some(Sc::Input) => Ok(crate::StorageClass::Input), - Some(Sc::Output) => Ok(crate::StorageClass::Output), - Some(Sc::Private) => Ok(crate::StorageClass::Private), - Some(Sc::StorageBuffer) => Ok(crate::StorageClass::StorageBuffer), - Some(Sc::Uniform) => Ok(crate::StorageClass::Uniform), - Some(Sc::Workgroup) => Ok(crate::StorageClass::WorkGroup), - _ => Err(Error::UnsupportedStorageClass(word)), - } -} - pub fn map_image_dim(word: spirv::Word) -> Result { use spirv::Dim as D; match D::from_u32(word) { diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 38ca9aea0b..61dc3d5fd8 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -2274,19 +2274,39 @@ impl> Parser { .future_decor .remove(&id) .ok_or(Error::InvalidBinding(id))?; - let class = map_storage_class(storage_class)?; + + let class = { + use spirv::StorageClass as Sc; + match Sc::from_u32(storage_class) { + Some(Sc::Function) => crate::StorageClass::Function, + Some(Sc::Input) => crate::StorageClass::Input, + Some(Sc::Output) => crate::StorageClass::Output, + Some(Sc::Private) => crate::StorageClass::Private, + Some(Sc::UniformConstant) => crate::StorageClass::Handle, + Some(Sc::StorageBuffer) => crate::StorageClass::Storage, + Some(Sc::Uniform) => { + if self + .lookup_storage_buffer_types + .contains(&lookup_type.handle) + { + crate::StorageClass::Storage + } else { + crate::StorageClass::Uniform + } + } + Some(Sc::Workgroup) => crate::StorageClass::WorkGroup, + Some(Sc::PushConstant) => crate::StorageClass::PushConstant, + _ => return Err(Error::UnsupportedStorageClass(storage_class)), + } + }; + let binding = match (class, &module.types[lookup_type.handle].inner) { (crate::StorageClass::Input, &crate::TypeInner::Struct { .. }) | (crate::StorageClass::Output, &crate::TypeInner::Struct { .. }) => None, _ => Some(dec.get_binding().ok_or(Error::InvalidBinding(id))?), }; let is_storage = match module.types[lookup_type.handle].inner { - crate::TypeInner::Struct { .. } => match class { - crate::StorageClass::StorageBuffer => true, - _ => self - .lookup_storage_buffer_types - .contains(&lookup_type.handle), - }, + crate::TypeInner::Struct { .. } => class == crate::StorageClass::Storage, crate::TypeInner::Image { class: crate::ImageClass::Storage(_), .. diff --git a/src/front/wgsl/conv.rs b/src/front/wgsl/conv.rs index 0b25ad9435..6fdea89e7a 100644 --- a/src/front/wgsl/conv.rs +++ b/src/front/wgsl/conv.rs @@ -4,8 +4,9 @@ pub fn map_storage_class(word: &str) -> Result> { match word { "in" => Ok(crate::StorageClass::Input), "out" => Ok(crate::StorageClass::Output), + "private" => Ok(crate::StorageClass::Private), "uniform" => Ok(crate::StorageClass::Uniform), - "storage_buffer" => Ok(crate::StorageClass::StorageBuffer), + "storage" => Ok(crate::StorageClass::Storage), _ => Err(Error::UnknownStorageClass(word)), } } diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 5e8a3824e8..05c2749946 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -930,8 +930,17 @@ impl Parser { lexer.expect(Token::Separator(':'))?; let ty = self.parse_type_decl(lexer, None, type_arena)?; let access = match class { - Some(crate::StorageClass::StorageBuffer) => crate::StorageAccess::all(), - Some(crate::StorageClass::Constant) => crate::StorageAccess::LOAD, + Some(crate::StorageClass::Storage) => crate::StorageAccess::all(), + Some(crate::StorageClass::Handle) => { + match type_arena[ty].inner { + //TODO: RW textures + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => crate::StorageAccess::LOAD, + _ => crate::StorageAccess::empty(), + } + } _ => crate::StorageAccess::empty(), }; if lexer.skip(Token::Operation('=')) { @@ -1708,7 +1717,7 @@ impl Parser { crate::BuiltIn::Position => crate::StorageClass::Output, _ => unimplemented!(), }, - _ => crate::StorageClass::Private, + _ => crate::StorageClass::Handle, }, }; let var_handle = module.global_variables.append(crate::GlobalVariable { diff --git a/src/lib.rs b/src/lib.rs index e0a75e72ec..63ca33f9a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,14 +108,24 @@ pub enum ShaderStage { #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[allow(missing_docs)] // The names are self evident pub enum StorageClass { - Constant, + /// Function locals. Function, + /// Pipeline input, per invocation. Input, + /// Pipeline output, per invocation, mutable. Output, + /// Private data, per invocation, mutable. Private, - StorageBuffer, - Uniform, + /// Workgroup shared data, mutable. WorkGroup, + /// Uniform buffer data. + Uniform, + /// Storage buffer data, potentially mutable. + Storage, + /// Opaque handles, such as samplers and images. + Handle, + /// Push constants. + PushConstant, } /// Built-in inputs and outputs. diff --git a/src/proc/interface.rs b/src/proc/interface.rs index 33767f597f..670f2da506 100644 --- a/src/proc/interface.rs +++ b/src/proc/interface.rs @@ -218,7 +218,7 @@ mod tests { fn global_use_scan() { let test_global = GlobalVariable { name: None, - class: StorageClass::Constant, + class: StorageClass::Uniform, binding: None, ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()), interpolation: None, diff --git a/src/proc/validator.rs b/src/proc/validator.rs index ffcf49492b..bfb2858c61 100644 --- a/src/proc/validator.rs +++ b/src/proc/validator.rs @@ -21,8 +21,11 @@ pub enum GlobalVariableError { InvalidType, #[error("Interpolation is not valid")] InvalidInterpolation, - #[error("Storage access flags are invalid")] - InvalidStorageAccess, + #[error("Storage access {seen:?} exceed the allowed {allowed:?}")] + InvalidStorageAccess { + allowed: crate::StorageAccess, + seen: crate::StorageAccess, + }, #[error("Binding decoration is missing or not applicable")] InvalidBinding, #[error("Binding is out of range")] @@ -63,8 +66,12 @@ pub enum ValidationError { InvalidTypeWidth(crate::ScalarKind, crate::Bytes), #[error("The type handle {0:?} can not be resolved")] UnresolvedType(Handle), - #[error("Global variable {0:?} is invalid: {1:?}")] - GlobalVariable(Handle, GlobalVariableError), + #[error("Global variable {handle:?} '{name}' is invalid: {error:?}")] + GlobalVariable { + handle: Handle, + name: String, + error: GlobalVariableError, + }, #[error("Function {0:?} is invalid: {1:?}")] Function(Handle, FunctionError), #[error("Entry point {name} at {stage:?} is invalid: {error:?}")] @@ -75,6 +82,41 @@ pub enum ValidationError { }, } +impl crate::GlobalVariable { + fn forbid_interpolation(&self) -> Result<(), GlobalVariableError> { + match self.interpolation { + Some(_) => Err(GlobalVariableError::InvalidInterpolation), + None => Ok(()), + } + } + + fn check_resource(&self) -> Result<(), GlobalVariableError> { + match self.binding { + Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point + Some(crate::Binding::Resource { group, binding }) => { + if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES { + return Err(GlobalVariableError::OutOfRangeBinding); + } + } + Some(crate::Binding::Location(_)) | None => { + return Err(GlobalVariableError::InvalidBinding) + } + } + self.forbid_interpolation() + } +} + +fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse { + let mut storage_usage = crate::GlobalUse::empty(); + if access.contains(crate::StorageAccess::LOAD) { + storage_usage |= crate::GlobalUse::LOAD; + } + if access.contains(crate::StorageAccess::STORE) { + storage_usage |= crate::GlobalUse::STORE; + } + storage_usage +} + impl Validator { /// Construct a new validator instance. pub fn new() -> Self { @@ -89,15 +131,13 @@ impl Validator { types: &Arena, ) -> Result<(), GlobalVariableError> { log::debug!("var {:?}", var); - let is_storage = match var.class { + let allowed_storage_access = match var.class { crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage), crate::StorageClass::Input | crate::StorageClass::Output => { match var.binding { Some(crate::Binding::BuiltIn(_)) => { // validated per entry point - if var.interpolation.is_some() { - return Err(GlobalVariableError::InvalidInterpolation); - } + var.forbid_interpolation()? } Some(crate::Binding::Location(loc)) => { if loc > MAX_LOCATIONS { @@ -117,56 +157,53 @@ impl Validator { match types[var.ty].inner { //TODO: check the member types crate::TypeInner::Struct { members: _ } => { - if var.interpolation.is_some() { - return Err(GlobalVariableError::InvalidInterpolation); - } + var.forbid_interpolation()? } _ => return Err(GlobalVariableError::InvalidType), } } } - false + crate::StorageAccess::empty() } - crate::StorageClass::Constant - | crate::StorageClass::StorageBuffer - | crate::StorageClass::Uniform => { - match var.binding { - Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point - Some(crate::Binding::Resource { group, binding }) => { - if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES { - return Err(GlobalVariableError::OutOfRangeBinding); - } - } - Some(crate::Binding::Location(_)) | None => { - return Err(GlobalVariableError::InvalidBinding) - } - } - if var.interpolation.is_some() { - return Err(GlobalVariableError::InvalidInterpolation); - } - //TODO: prevent `Uniform` storage class with `STORE` access + crate::StorageClass::Storage => { + var.check_resource()?; + crate::StorageAccess::all() + } + crate::StorageClass::Uniform => { + var.check_resource()?; + crate::StorageAccess::empty() + } + crate::StorageClass::Handle => { + var.check_resource()?; match types[var.ty].inner { - crate::TypeInner::Struct { .. } - | crate::TypeInner::Image { + crate::TypeInner::Image { class: crate::ImageClass::Storage(_), .. - } => true, - _ => false, + } => crate::StorageAccess::all(), + _ => crate::StorageAccess::empty(), } } crate::StorageClass::Private | crate::StorageClass::WorkGroup => { if var.binding.is_some() { return Err(GlobalVariableError::InvalidBinding); } - if var.interpolation.is_some() { - return Err(GlobalVariableError::InvalidInterpolation); - } - false + var.forbid_interpolation()?; + crate::StorageAccess::empty() + } + crate::StorageClass::PushConstant => { + //TODO + return Err(GlobalVariableError::InvalidStorageAccess { + allowed: crate::StorageAccess::empty(), + seen: crate::StorageAccess::empty(), + }); } }; - if !is_storage && !var.storage_access.is_empty() { - return Err(GlobalVariableError::InvalidStorageAccess); + if !allowed_storage_access.contains(var.storage_access) { + return Err(GlobalVariableError::InvalidStorageAccess { + allowed: allowed_storage_access, + seen: var.storage_access, + }); } Ok(()) @@ -291,26 +328,19 @@ impl Validator { location_out_mask |= mask; crate::GlobalUse::LOAD | crate::GlobalUse::STORE } - crate::StorageClass::Constant => crate::GlobalUse::LOAD, - crate::StorageClass::Uniform | crate::StorageClass::StorageBuffer => { - //TODO: built-in checks? - let mut storage_usage = crate::GlobalUse::empty(); - if var.storage_access.contains(crate::StorageAccess::LOAD) { - storage_usage |= crate::GlobalUse::LOAD; - } - if var.storage_access.contains(crate::StorageAccess::STORE) { - storage_usage |= crate::GlobalUse::STORE; - } - if storage_usage.is_empty() { - // its a uniform buffer - crate::GlobalUse::LOAD - } else { - storage_usage - } - } + crate::StorageClass::Uniform => crate::GlobalUse::LOAD, + crate::StorageClass::Storage => storage_usage(var.storage_access), + crate::StorageClass::Handle => match module.types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => storage_usage(var.storage_access), + _ => crate::GlobalUse::LOAD, + }, crate::StorageClass::Private | crate::StorageClass::WorkGroup => { crate::GlobalUse::all() } + crate::StorageClass::PushConstant => crate::GlobalUse::LOAD, }; if !allowed_usage.contains(usage) { log::warn!("\tUsage error for: {:?}", var); @@ -384,7 +414,11 @@ impl Validator { for (var_handle, var) in module.global_variables.iter() { self.validate_global_var(var, &module.types) - .map_err(|e| ValidationError::GlobalVariable(var_handle, e))?; + .map_err(|error| ValidationError::GlobalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + error, + })?; } for (fun_handle, fun) in module.functions.iter() { diff --git a/test-data/boids.wgsl b/test-data/boids.wgsl index 5df5e60611..9bc98052e9 100644 --- a/test-data/boids.wgsl +++ b/test-data/boids.wgsl @@ -61,8 +61,8 @@ type Particles = struct { }; [[group(0), binding(0)]] var params : SimParams; -[[group(0), binding(1)]] var particlesA : Particles; -[[group(0), binding(2)]] var particlesB : Particles; +[[group(0), binding(1)]] var particlesA : Particles; +[[group(0), binding(2)]] var particlesB : Particles; [[builtin(global_invocation_id)]] var gl_GlobalInvocationID : vec3; diff --git a/test-data/quad.wgsl b/test-data/quad.wgsl index 3e56d317f3..16cbd2698a 100644 --- a/test-data/quad.wgsl +++ b/test-data/quad.wgsl @@ -14,8 +14,8 @@ fn main() -> void { # fragment [[location(0)]] var v_uv : vec2; -[[group(0), binding(0)]] var u_texture : texture_sampled_2d; -[[group(0), binding(1)]] var u_sampler : sampler; +[[group(0), binding(0)]] var u_texture : texture_sampled_2d; +[[group(0), binding(1)]] var u_sampler : sampler; [[location(0)]] var o_color : vec4; [[stage(fragment)]]