diff --git a/src/arena.rs b/src/arena.rs index 460d909c2d..0f673fa684 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -118,7 +118,7 @@ impl Arena { /// Returns an iterator over the items stored in this arena, returning both /// the item's handle and a reference to it. - pub fn iter(&self) -> impl Iterator, &T)> { + pub fn iter(&self) -> impl DoubleEndedIterator, &T)> { self.data.iter().enumerate().map(|(i, v)| { let position = i + 1; let index = unsafe { Index::new_unchecked(position as u32) }; diff --git a/src/proc/validator.rs b/src/proc/validator.rs index f6b1f153f2..934ffabf29 100644 --- a/src/proc/validator.rs +++ b/src/proc/validator.rs @@ -4,11 +4,20 @@ use bit_set::BitSet; const MAX_WORKGROUP_SIZE: u32 = 0x4000; +bitflags::bitflags! { + #[repr(transparent)] + struct TypeFlags: u8 { + const INTERFACE = 1; + const HOST_SHARED = 2; + } +} + #[derive(Debug)] pub struct Validator { //Note: this is a bit tricky: some of the front-ends as well as backends // already have to use the typifier, so the work here is redundant in a way. typifier: Typifier, + type_flags: Vec, location_in_mask: BitSet, location_out_mask: BitSet, bind_group_masks: Vec, @@ -22,6 +31,10 @@ pub enum TypeError { UnresolvedBase(Handle), #[error("The constant {0:?} can not be used for an array size")] InvalidArraySizeConstant(Handle), + #[error("Array doesn't have a stride decoration, can't be host-shared")] + MissingStrideDecoration, + #[error("Structure doesn't have a block decoration, can't be host-shared")] + MissingBlockDecoration, } #[derive(Clone, Debug, PartialEq, thiserror::Error)] @@ -40,8 +53,6 @@ pub enum GlobalVariableError { InvalidUsage, #[error("Type isn't compatible with the storage class")] InvalidType, - #[error("Structure doesn't have a block decoration, can't be host-shared")] - MissingBlockDecoration, #[error("Interpolation is not valid")] InvalidInterpolation, #[error("Storage access {seen:?} exceed the allowed {allowed:?}")] @@ -224,22 +235,6 @@ impl crate::GlobalVariable { } } -impl crate::TypeInner { - fn check_block(&self) -> Result<(), GlobalVariableError> { - match *self { - Self::Struct { - block: true, - members: _, - } => Ok(()), - Self::Struct { - block: false, - members: _, - } => Err(GlobalVariableError::MissingBlockDecoration), - _ => Err(GlobalVariableError::InvalidType), - } - } -} - fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse { let mut storage_usage = crate::GlobalUse::empty(); if access.contains(crate::StorageAccess::LOAD) { @@ -280,6 +275,7 @@ impl Validator { pub fn new() -> Self { Validator { typifier: Typifier::new(), + type_flags: Vec::new(), location_in_mask: BitSet::new(), location_out_mask: BitSet::new(), bind_group_masks: Vec::new(), @@ -293,6 +289,29 @@ impl Validator { } } + fn fill_type_flags(&mut self, arena: &Arena) { + for (handle, ty) in arena.iter().rev() { + let flags = self.type_flags[handle.index()]; + match ty.inner { + crate::TypeInner::Array { base, .. } => { + //Note: don't assume anything about the indices, + // they are checked in `validate_type` later on. + if let Some(f) = self.type_flags.get_mut(base.index()) { + *f |= flags; + } + } + crate::TypeInner::Struct { ref members, .. } => { + for member in members { + if let Some(f) = self.type_flags.get_mut(member.ty.index()) { + *f |= flags; + } + } + } + _ => {} + } + } + } + fn validate_type( &self, ty: &crate::Type, @@ -316,7 +335,7 @@ impl Validator { return Err(TypeError::UnresolvedBase(base)); } } - Ti::Array { base, size, .. } => { + Ti::Array { base, size, stride } => { if base >= handle { return Err(TypeError::UnresolvedBase(base)); } @@ -335,12 +354,18 @@ impl Validator { } } } + //TODO: check stride + if stride.is_none() + && self.type_flags[handle.index()].contains(TypeFlags::HOST_SHARED) + { + return Err(TypeError::MissingStrideDecoration); + } } - Ti::Struct { - block: _, - ref members, - } => { - //TODO: check the offsets + Ti::Struct { block, ref members } => { + if !block && self.type_flags[handle.index()].contains(TypeFlags::HOST_SHARED) { + return Err(TypeError::MissingBlockDecoration); + } + //TODO: check the spans for member in members { if member.ty >= handle { return Err(TypeError::UnresolvedBase(member.ty)); @@ -396,42 +421,47 @@ impl Validator { &self, var: &crate::GlobalVariable, types: &Arena, - ) -> Result<(), GlobalVariableError> { + ) -> Result { log::debug!("var {:?}", var); - let allowed_storage_access = match var.class { + let (allowed_storage_access, type_flags) = match var.class { crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage), crate::StorageClass::Input | crate::StorageClass::Output => { var.check_varying(types)?; - crate::StorageAccess::empty() + (crate::StorageAccess::empty(), TypeFlags::INTERFACE) } crate::StorageClass::Storage => { var.check_resource()?; - let ty = &types[var.ty]; - ty.inner.check_block()?; - crate::StorageAccess::all() + match types[var.ty].inner { + crate::TypeInner::Struct { .. } => (), + _ => return Err(GlobalVariableError::InvalidType), + } + (crate::StorageAccess::all(), TypeFlags::HOST_SHARED) } crate::StorageClass::Uniform => { var.check_resource()?; - let ty = &types[var.ty]; - ty.inner.check_block()?; - crate::StorageAccess::empty() + match types[var.ty].inner { + crate::TypeInner::Struct { .. } => (), + _ => return Err(GlobalVariableError::InvalidType), + } + (crate::StorageAccess::empty(), TypeFlags::HOST_SHARED) } crate::StorageClass::Handle => { var.check_resource()?; - match types[var.ty].inner { + let allowed_access = match types[var.ty].inner { crate::TypeInner::Image { class: crate::ImageClass::Storage(_), .. } => crate::StorageAccess::all(), _ => crate::StorageAccess::empty(), - } + }; + (allowed_access, TypeFlags::empty()) } crate::StorageClass::Private | crate::StorageClass::WorkGroup => { if var.binding.is_some() { return Err(GlobalVariableError::InvalidBinding); } var.forbid_interpolation()?; - crate::StorageAccess::empty() + (crate::StorageAccess::empty(), TypeFlags::empty()) } crate::StorageClass::PushConstant => { //TODO @@ -449,7 +479,7 @@ impl Validator { }); } - Ok(()) + Ok(type_flags) } fn validate_local_var( @@ -665,15 +695,9 @@ impl Validator { /// Check the given module to be valid. pub fn validate(&mut self, module: &crate::Module) -> Result<(), ValidationError> { self.typifier.clear(); - - for (handle, ty) in module.types.iter() { - self.validate_type(ty, handle, &module.constants) - .map_err(|error| ValidationError::Type { - handle, - name: ty.name.clone().unwrap_or_default(), - error, - })?; - } + self.type_flags.clear(); + self.type_flags + .resize(module.types.len(), TypeFlags::empty()); for (handle, constant) in module.constants.iter() { self.validate_constant(handle, &module.constants, &module.types) @@ -683,13 +707,28 @@ impl Validator { error, })?; } + for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, &module.types) + let ty_flags = self + .validate_global_var(var, &module.types) .map_err(|error| ValidationError::GlobalVariable { handle: var_handle, name: var.name.clone().unwrap_or_default(), error, })?; + self.type_flags[var.ty.index()] |= ty_flags; + } + + self.fill_type_flags(&module.types); + + // doing after the globals, so that `type_flags` is ready + for (handle, ty) in module.types.iter() { + self.validate_type(ty, handle, &module.constants) + .map_err(|error| ValidationError::Type { + handle, + name: ty.name.clone().unwrap_or_default(), + error, + })?; } for (fun_handle, fun) in module.functions.iter() {