From e7c1415ca49f08a56fa98d580dda63339974e207 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 9 Mar 2021 00:54:46 -0500 Subject: [PATCH] [spv-in] really consider variables as pointers --- src/front/spv/convert.rs | 17 +++ src/front/spv/error.rs | 1 - src/front/spv/mod.rs | 187 ++++++++++++------------ tests/out/shadow.ron.snap | 289 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 398 insertions(+), 96 deletions(-) diff --git a/src/front/spv/convert.rs b/src/front/spv/convert.rs index 2695067192..de0df57e2e 100644 --- a/src/front/spv/convert.rs +++ b/src/front/spv/convert.rs @@ -144,3 +144,20 @@ pub fn map_builtin(word: spirv::Word, is_output: bool) -> Result return Err(Error::UnsupportedBuiltIn(word)), }) } + +pub fn map_storage_class(word: spirv::Word) -> Result { + use spirv::StorageClass as Sc; + Ok(match Sc::from_u32(word) { + Some(Sc::Function) => crate::StorageClass::Function, + Some(Sc::Input) => crate::StorageClass::Input, + Some(Sc::Output) => crate::StorageClass::Output, + Some(Sc::Private) => crate::StorageClass::Private, + Some(Sc::UniformConstant) => crate::StorageClass::Handle, + Some(Sc::StorageBuffer) => crate::StorageClass::Storage, + // we expect the `Storage` case to be filtered out before calling this function. + Some(Sc::Uniform) => crate::StorageClass::Uniform, + Some(Sc::Workgroup) => crate::StorageClass::WorkGroup, + Some(Sc::PushConstant) => crate::StorageClass::PushConstant, + _ => return Err(Error::UnsupportedStorageClass(word)), + }) +} diff --git a/src/front/spv/error.rs b/src/front/spv/error.rs index b07053f986..85bb5f502b 100644 --- a/src/front/spv/error.rs +++ b/src/front/spv/error.rs @@ -34,7 +34,6 @@ pub enum Error { InvalidSign(spirv::Word), InvalidInnerType(spirv::Word), InvalidVectorSize(spirv::Word), - InvalidVariableClass(spirv::StorageClass), InvalidAccessType(spirv::Word), InvalidAccess(crate::Expression), InvalidAccessIndex(spirv::Word), diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index d43c5f882c..e20f9cdb1f 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -610,12 +610,7 @@ impl> Parser { let result_type_id = self.next()?; let result_id = self.next()?; - let storage = self.next()?; - match spirv::StorageClass::from_u32(storage) { - Some(spirv::StorageClass::Function) => (), - Some(class) => return Err(Error::InvalidVariableClass(class)), - None => return Err(Error::UnsupportedStorageClass(storage)), - } + let _storage_class = self.next()?; let init = if inst.wc > 4 { inst.expect(5)?; let init_id = self.next()?; @@ -632,9 +627,13 @@ impl> Parser { if let Some(ref name) = name { log::debug!("\t\t\tid={} name={}", result_id, name); } + let lookup_ty = self.lookup_type.lookup(result_type_id)?; let var_handle = local_arena.append(crate::LocalVariable { name, - ty: self.lookup_type.lookup(result_type_id)?.handle, + ty: match type_arena[lookup_ty.handle].inner { + crate::TypeInner::Pointer { base, .. } => base, + _ => lookup_ty.handle, + }, init, }); @@ -694,10 +693,13 @@ impl> Parser { let base_id = self.next()?; log::trace!("\t\t\tlooking up expr {:?}", base_id); let mut acex = { - let expr = self.lookup_expression.lookup(base_id)?; + // the base type has to be a pointer, + // so we derefernce it here for the traversal + let lexp = self.lookup_expression.lookup(base_id)?; + let lty = self.lookup_type.lookup(lexp.type_id)?; AccessExpression { - base_handle: expr.handle, - type_id: expr.type_id, + base_handle: lexp.handle, + type_id: lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))?, } }; for _ in 4..inst.wc { @@ -714,7 +716,10 @@ impl> Parser { kind: crate::ScalarKind::Sint, .. } => (), - _ => return Err(Error::UnsupportedType(index_type_handle)), + ref other => { + log::warn!("access index type {:?}", other); + return Err(Error::UnsupportedType(index_type_handle)); + } } log::trace!("\t\t\tlooking up type {:?}", acex.type_id); let type_lookup = self.lookup_type.lookup(acex.type_id)?; @@ -753,9 +758,7 @@ impl> Parser { .ok_or(Error::InvalidAccessType(acex.type_id))?, } } - crate::TypeInner::Array { .. } - | crate::TypeInner::Vector { .. } - | crate::TypeInner::Matrix { .. } => AccessExpression { + _ => AccessExpression { base_handle: expressions.append(crate::Expression::Access { base: acex.base_handle, index: index_expr.handle, @@ -764,7 +767,6 @@ impl> Parser { .base_id .ok_or(Error::InvalidAccessType(acex.type_id))?, }, - _ => return Err(Error::UnsupportedType(type_lookup.handle)), }; } @@ -895,7 +897,10 @@ impl> Parser { | crate::TypeInner::Matrix { .. } => type_lookup .base_id .ok_or(Error::InvalidAccessType(lexp.type_id))?, - _ => return Err(Error::UnsupportedType(type_lookup.handle)), + ref other => { + log::warn!("composite type {:?}", other); + return Err(Error::UnsupportedType(type_lookup.handle)); + } }; lexp = LookupExpression { handle: expressions.append(crate::Expression::AccessIndex { @@ -2191,15 +2196,44 @@ impl> Parser { fn parse_type_pointer( &mut self, inst: Instruction, - _module: &mut crate::Module, + module: &mut crate::Module, ) -> Result<(), Error> { self.switch(ModuleState::Type, inst.op)?; inst.expect(4)?; let id = self.next()?; - let _storage = self.next()?; + let storage_class = self.next()?; let type_id = self.next()?; - let type_lookup = self.lookup_type.lookup(type_id)?.clone(); - self.lookup_type.insert(id, type_lookup); // don't register pointers in the IR + + let decor = self.future_decor.remove(&id); + let base_lookup_ty = self.lookup_type.lookup(type_id)?; + let class = match module.types[base_lookup_ty.handle].inner { + crate::TypeInner::Pointer { class, .. } + | crate::TypeInner::ValuePointer { class, .. } => class, + _ if self + .lookup_storage_buffer_types + .contains(&base_lookup_ty.handle) => + { + crate::StorageClass::Storage + } + _ => map_storage_class(storage_class)?, + }; + + // Don't bother with pointer stuff for `Handle` types. + let lookup_ty = if class == crate::StorageClass::Handle { + base_lookup_ty.clone() + } else { + LookupType { + handle: module.types.append(crate::Type { + name: decor.and_then(|dec| dec.name), + inner: crate::TypeInner::Pointer { + base: base_lookup_ty.handle, + class, + }, + }), + base_id: Some(type_id), + } + }; + self.lookup_type.insert(id, lookup_ty); Ok(()) } @@ -2584,7 +2618,10 @@ impl> Parser { crate::ConstantInner::Composite { ty, components } } //TODO: handle matrices, arrays, and structures - _ => return Err(Error::UnsupportedType(type_lookup.handle)), + ref other => { + log::warn!("null constant type {:?}", other); + return Err(Error::UnsupportedType(type_lookup.handle)); + } }; self.lookup_constant.insert( @@ -2644,43 +2681,25 @@ impl> Parser { } else { None }; - let lookup_type = self.lookup_type.lookup(type_id)?; let dec = self.future_decor.remove(&id).unwrap_or_default(); - let class = { - use spirv::StorageClass as Sc; - match Sc::from_u32(storage_class) { - Some(Sc::Function) => crate::StorageClass::Function, - Some(Sc::Input) => crate::StorageClass::Input, - Some(Sc::Output) => crate::StorageClass::Output, - Some(Sc::Private) => crate::StorageClass::Private, - Some(Sc::UniformConstant) => crate::StorageClass::Handle, - Some(Sc::StorageBuffer) => crate::StorageClass::Storage, - Some(Sc::Uniform) => { - if self - .lookup_storage_buffer_types - .contains(&lookup_type.handle) - { - crate::StorageClass::Storage - } else { - crate::StorageClass::Uniform - } - } - Some(Sc::Workgroup) => crate::StorageClass::WorkGroup, - Some(Sc::PushConstant) => crate::StorageClass::PushConstant, - _ => return Err(Error::UnsupportedStorageClass(storage_class)), + let mut effective_ty = self.lookup_type.lookup(type_id)?.handle; + let is_storage = match module.types[effective_ty].inner { + crate::TypeInner::Pointer { base, class } => { + effective_ty = base; + class == crate::StorageClass::Storage } - }; - - let ty_inner = &module.types[lookup_type.handle].inner; - let is_storage = match *ty_inner { - crate::TypeInner::Struct { .. } => class == crate::StorageClass::Storage, crate::TypeInner::Image { class: crate::ImageClass::Storage(_), .. } => true, _ => false, }; + let class = if self.lookup_storage_buffer_types.contains(&effective_ty) { + crate::StorageClass::Storage + } else { + map_storage_class(storage_class)? + }; let storage_access = if is_storage { let mut access = crate::StorageAccess::all(); @@ -2696,48 +2715,45 @@ impl> Parser { }; let binding = dec.get_binding(class == crate::StorageClass::Output); - let ty = match binding { + if let Some(crate::Binding::BuiltIn(built_in)) = binding { // SPIR-V only cares about some of the built-in types being integer. // Naga requires them to be strictly unsigned, so we have to patch it. - Some(crate::Binding::BuiltIn(built_in)) => { - let needs_inner_uint = match built_in { - crate::BuiltIn::BaseInstance - | crate::BuiltIn::BaseVertex - | crate::BuiltIn::InstanceIndex - | crate::BuiltIn::SampleIndex - | crate::BuiltIn::VertexIndex - | crate::BuiltIn::LocalInvocationIndex => Some(crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - width: 4, - }), - crate::BuiltIn::GlobalInvocationId - | crate::BuiltIn::LocalInvocationId - | crate::BuiltIn::WorkGroupId - | crate::BuiltIn::WorkGroupSize => Some(crate::TypeInner::Vector { - size: crate::VectorSize::Tri, - kind: crate::ScalarKind::Uint, - width: 4, - }), - _ => None, - }; - match (needs_inner_uint, ty_inner.scalar_kind()) { - (Some(inner), Some(crate::ScalarKind::Sint)) => { - log::warn!("Treating {:?} as unsigned", built_in); - module - .types - .fetch_or_append(crate::Type { name: None, inner }) - } - _ => lookup_type.handle, - } + let needs_inner_uint = match built_in { + crate::BuiltIn::BaseInstance + | crate::BuiltIn::BaseVertex + | crate::BuiltIn::InstanceIndex + | crate::BuiltIn::SampleIndex + | crate::BuiltIn::VertexIndex + | crate::BuiltIn::LocalInvocationIndex => Some(crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }), + crate::BuiltIn::GlobalInvocationId + | crate::BuiltIn::LocalInvocationId + | crate::BuiltIn::WorkGroupId + | crate::BuiltIn::WorkGroupSize => Some(crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + kind: crate::ScalarKind::Uint, + width: 4, + }), + _ => None, + }; + if let (Some(inner), Some(crate::ScalarKind::Sint)) = ( + needs_inner_uint, + module.types[effective_ty].inner.scalar_kind(), + ) { + log::warn!("Treating {:?} as unsigned", built_in); + effective_ty = module + .types + .fetch_or_append(crate::Type { name: None, inner }); } - _ => lookup_type.handle, - }; + } let var = crate::GlobalVariable { name: dec.name, class, binding, - ty, + ty: effective_ty, init, interpolation: dec.interpolation, storage_access, @@ -2746,10 +2762,7 @@ impl> Parser { self.lookup_variable .insert(id, LookupVariable { handle, type_id }); - if module.types[lookup_type.handle] - .inner - .can_comparison_sample() - { + if module.types[effective_ty].inner.can_comparison_sample() { log::debug!("\t\ttracking {:?} for sampling properties", handle); self.handle_sampling .insert(handle, image::SamplingFlags::empty()); diff --git a/tests/out/shadow.ron.snap b/tests/out/shadow.ron.snap index 3d4688cb7d..d9ddcc9108 100644 --- a/tests/out/shadow.ron.snap +++ b/tests/out/shadow.ron.snap @@ -81,6 +81,20 @@ expression: output width: 4, ), ), + ( + name: None, + inner: Pointer( + base: 2, + class: Function, + ), + ), + ( + name: None, + inner: Pointer( + base: 3, + class: Function, + ), + ), ( name: None, inner: Vector( @@ -97,11 +111,32 @@ expression: output ( name: Some("num_lights"), span: None, - ty: 11, + ty: 13, ), ], ), ), + ( + name: None, + inner: Pointer( + base: 14, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 13, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 3, + class: Uniform, + ), + ), ( name: None, inner: Matrix( @@ -118,7 +153,7 @@ expression: output ( name: Some("proj"), span: None, - ty: 13, + ty: 18, ), ( name: Some("pos"), @@ -136,7 +171,7 @@ expression: output ( name: None, inner: Array( - base: 14, + base: 19, size: Dynamic, stride: Some(96), ), @@ -149,11 +184,249 @@ expression: output ( name: Some("data"), span: None, - ty: 15, + ty: 20, ), ], ), ), + ( + name: None, + inner: Pointer( + base: 21, + class: Storage, + ), + ), + ( + name: None, + inner: Pointer( + base: 20, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 19, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 18, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 4, + class: Input, + ), + ), + ( + name: None, + inner: Pointer( + base: 2, + class: Input, + ), + ), + ( + name: None, + inner: Pointer( + base: 20, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 19, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 4, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 20, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 19, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 4, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 20, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 19, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 4, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Input, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Input, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Input, + ), + ), + ( + name: None, + inner: Pointer( + base: 20, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 19, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 4, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 20, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 19, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 4, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 20, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 19, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 4, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Uniform, + ), + ), + ( + name: None, + inner: Pointer( + base: 4, + class: Output, + ), + ), ( name: None, inner: Image( @@ -495,7 +768,7 @@ expression: output group: 0, binding: 2, )), - ty: 17, + ty: 56, init: None, interpolation: None, storage_access: ( @@ -509,7 +782,7 @@ expression: output group: 0, binding: 3, )), - ty: 18, + ty: 57, init: None, interpolation: None, storage_access: ( @@ -523,7 +796,7 @@ expression: output group: 0, binding: 0, )), - ty: 12, + ty: 14, init: None, interpolation: None, storage_access: ( @@ -537,7 +810,7 @@ expression: output group: 0, binding: 1, )), - ty: 16, + ty: 21, init: None, interpolation: None, storage_access: (