From f3ea2130a45f06a1fbbd6e7f0364a32dcd7d3965 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 16 Sep 2021 14:42:13 -0700 Subject: [PATCH] [wgsl-in]: Correctly compare pointer types. Treat `TypeInner::ValuePointer` and `TypeInner::Pointer` as equivalent by converting them to a canonical form before comparison. Support `ValuePointer` in WGSL type output. Fixes #1318. --- src/back/wgsl/writer.rs | 55 ++++++++++++++++++++++++++++++++++++++--- src/front/wgsl/mod.rs | 4 +-- src/proc/mod.rs | 53 +++++++++++++++++++++++++++++++++++++++ src/valid/compose.rs | 12 +++++++-- src/valid/function.rs | 15 +++++++++-- tests/wgsl-errors.rs | 18 ++++++++++++++ 6 files changed, 148 insertions(+), 9 deletions(-) diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index f8956a691b..fd2d1adfb5 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -630,17 +630,66 @@ impl Writer { } TypeInner::Pointer { base, class } => { let (storage, maybe_access) = storage_class_str(class); + // Everything but `StorageClass::Handle` gives us a `storage` name, but + // Naga IR never produces pointers to handles, so it doesn't matter much + // how we write such a type. Just write it as the base type alone. if let Some(class) = storage { write!(self.out, "ptr<{}, ", class)?; - if let Some(access) = maybe_access { - write!(self.out, ", {}", access)?; - } } self.write_type(module, base)?; if storage.is_some() { + if let Some(access) = maybe_access { + write!(self.out, ", {}", access)?; + } write!(self.out, ">")?; } } + TypeInner::ValuePointer { + size: None, + kind, + width: _, + class, + } => { + let (storage, maybe_access) = storage_class_str(class); + if let Some(class) = storage { + write!(self.out, "ptr<{}, {}", class, scalar_kind_str(kind))?; + if let Some(access) = maybe_access { + write!(self.out, ", {}", access)?; + } + write!(self.out, ">")?; + } else { + return Err(Error::Unimplemented(format!( + "ValuePointer to StorageClass::Handle {:?}", + inner + ))); + } + } + TypeInner::ValuePointer { + size: Some(size), + kind, + width: _, + class, + } => { + let (storage, maybe_access) = storage_class_str(class); + if let Some(class) = storage { + write!( + self.out, + "ptr<{}, vec{}<{}>", + class, + back::vector_size_str(size), + scalar_kind_str(kind) + )?; + if let Some(access) = maybe_access { + write!(self.out, ", {}", access)?; + } + write!(self.out, ">")?; + } else { + return Err(Error::Unimplemented(format!( + "ValuePointer to StorageClass::Handle {:?}", + inner + ))); + } + } _ => { return Err(Error::Unimplemented(format!( "write_value_type {:?}", diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 6b13942a14..95b0e174c4 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -3231,7 +3231,7 @@ impl Parser { .resolve_type(expr_id)?; let expr_inner = context.typifier.get(expr_id, context.types); let given_inner = &context.types[ty].inner; - if given_inner != expr_inner { + if !given_inner.equivalent(expr_inner, context.types) { log::error!( "Given type {:?} doesn't match expected {:?}", given_inner, @@ -3292,7 +3292,7 @@ impl Parser { Some(ty) => { let expr_inner = context.typifier.get(value, context.types); let given_inner = &context.types[ty].inner; - if given_inner != expr_inner { + if !given_inner.equivalent(expr_inner, context.types) { log::error!( "Given type {:?} doesn't match expected {:?}", given_inner, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index d672e905c4..fd6015e623 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -6,6 +6,8 @@ mod namer; mod terminator; mod typifier; +use std::cmp::PartialEq; + pub use index::IndexableLength; pub use layouter::{Alignment, InvalidBaseType, Layouter, TypeLayout}; pub use namer::{EntryPointIndex, NameKey, Namer}; @@ -129,6 +131,57 @@ impl super::TypeInner { Self::Image { .. } | Self::Sampler { .. } => 0, } } + + /// Return the canoncal form of `self`, or `None` if it's already in + /// canonical form. + /// + /// Certain types have multiple representations in `TypeInner`. This + /// function converts all forms of equivalent types to a single + /// representative of their class, so that simply applying `Eq` to the + /// result indicates whether the types are equivalent, as far as Naga IR is + /// concerned. + pub fn canonical_form( + &self, + types: &crate::UniqueArena, + ) -> Option { + use crate::TypeInner as Ti; + match *self { + Ti::Pointer { base, class } => match types[base].inner { + Ti::Scalar { kind, width } => Some(Ti::ValuePointer { + size: None, + kind, + width, + class, + }), + Ti::Vector { size, kind, width } => Some(Ti::ValuePointer { + size: Some(size), + kind, + width, + class, + }), + _ => None, + }, + _ => None, + } + } + + /// Compare `self` and `rhs` as types. + /// + /// This is mostly the same as `::eq`, but it treats + /// `ValuePointer` and `Pointer` types as equivalent. + /// + /// When you know that one side of the comparison is never a pointer, it's + /// fine to not bother with canonicalization, and just compare `TypeInner` + /// values with `==`. + pub fn equivalent( + &self, + rhs: &crate::TypeInner, + types: &crate::UniqueArena, + ) -> bool { + let left = self.canonical_form(types); + let right = rhs.canonical_form(types); + left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs) + } } impl super::MathFunction { diff --git a/src/valid/compose.rs b/src/valid/compose.rs index 9f3aa82424..15148fc7ad 100644 --- a/src/valid/compose.rs +++ b/src/valid/compose.rs @@ -96,7 +96,11 @@ pub fn validate_compose( }); } for (index, comp_res) in component_resolutions.enumerate() { - if comp_res.inner_with(type_arena) != &type_arena[base].inner { + let base_inner = &type_arena[base].inner; + let comp_res_inner = comp_res.inner_with(type_arena); + // We don't support arrays of pointers, but it seems best not to + // embed that assumption here, so use `TypeInner::equivalent`. + if !base_inner.equivalent(comp_res_inner, type_arena) { log::error!("Array component[{}] type {:?}", index, comp_res); return Err(ComposeError::ComponentType { index: index as u32, @@ -113,7 +117,11 @@ pub fn validate_compose( } for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate() { - if comp_res.inner_with(type_arena) != &type_arena[member.ty].inner { + let member_inner = &type_arena[member.ty].inner; + let comp_res_inner = comp_res.inner_with(type_arena); + // We don't support pointers in structs, but it seems best not to embed + // that assumption here, so use `TypeInner::equivalent`. + if !comp_res_inner.equivalent(member_inner, type_arena) { log::error!("Struct component[{}] type {:?}", index, comp_res); return Err(ComposeError::ComponentType { index: index as u32, diff --git a/src/valid/function.rs b/src/valid/function.rs index 79917722bb..792b600e24 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -243,7 +243,8 @@ impl super::Validator { let ty = context .resolve_type_impl(expr, &self.valid_expression_set) .map_err(|error| CallError::Argument { index, error })?; - if ty != &context.types[arg.ty].inner { + let arg_inner = &context.types[arg.ty].inner; + if !ty.equivalent(arg_inner, context.types) { return Err(CallError::ArgumentType { index, required: arg.ty, @@ -448,7 +449,17 @@ impl super::Validator { .map(|expr| context.resolve_type(expr, &self.valid_expression_set)) .transpose()?; let expected_ty = context.return_type.map(|ty| &context.types[ty].inner); - if value_ty != expected_ty { + // We can't return pointers, but it seems best not to embed that + // assumption here, so use `TypeInner::equivalent` for comparison. + let okay = match (value_ty, expected_ty) { + (None, None) => true, + (Some(value_inner), Some(expected_inner)) => { + value_inner.equivalent(expected_inner, context.types) + } + (_, _) => false, + }; + + if !okay { log::error!( "Returning {:?} where {:?} is expected", value_ty, diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index 941865ab9d..f1b3fc55b0 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -687,6 +687,24 @@ fn invalid_functions() { } } +#[test] +fn pointer_type_equivalence() { + check_validation_error! { + r#" + fn f(pv: ptr>, pf: ptr) { } + + fn g() { + var m: mat2x2; + let pv: ptr> = &m.x; + let pf: ptr = &m.x.x; + + f(pv, pf); + } + "#: + Ok(_) + } +} + #[test] fn missing_bindings() { check_validation_error! {