Check host shared decorations in the validator

This commit is contained in:
Dzmitry Malyshau
2021-01-29 13:05:27 -05:00
committed by Timo de Kort
parent 0c5db60d69
commit ea64ab0431
2 changed files with 87 additions and 48 deletions

View File

@@ -118,7 +118,7 @@ impl<T> Arena<T> {
/// Returns an iterator over the items stored in this arena, returning both
/// the item's handle and a reference to it.
pub fn iter(&self) -> impl Iterator<Item = (Handle<T>, &T)> {
pub fn iter(&self) -> impl DoubleEndedIterator<Item = (Handle<T>, &T)> {
self.data.iter().enumerate().map(|(i, v)| {
let position = i + 1;
let index = unsafe { Index::new_unchecked(position as u32) };

View File

@@ -4,11 +4,20 @@ use bit_set::BitSet;
const MAX_WORKGROUP_SIZE: u32 = 0x4000;
bitflags::bitflags! {
#[repr(transparent)]
struct TypeFlags: u8 {
const INTERFACE = 1;
const HOST_SHARED = 2;
}
}
#[derive(Debug)]
pub struct Validator {
//Note: this is a bit tricky: some of the front-ends as well as backends
// already have to use the typifier, so the work here is redundant in a way.
typifier: Typifier,
type_flags: Vec<TypeFlags>,
location_in_mask: BitSet,
location_out_mask: BitSet,
bind_group_masks: Vec<BitSet>,
@@ -22,6 +31,10 @@ pub enum TypeError {
UnresolvedBase(Handle<crate::Type>),
#[error("The constant {0:?} can not be used for an array size")]
InvalidArraySizeConstant(Handle<crate::Constant>),
#[error("Array doesn't have a stride decoration, can't be host-shared")]
MissingStrideDecoration,
#[error("Structure doesn't have a block decoration, can't be host-shared")]
MissingBlockDecoration,
}
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
@@ -40,8 +53,6 @@ pub enum GlobalVariableError {
InvalidUsage,
#[error("Type isn't compatible with the storage class")]
InvalidType,
#[error("Structure doesn't have a block decoration, can't be host-shared")]
MissingBlockDecoration,
#[error("Interpolation is not valid")]
InvalidInterpolation,
#[error("Storage access {seen:?} exceed the allowed {allowed:?}")]
@@ -224,22 +235,6 @@ impl crate::GlobalVariable {
}
}
impl crate::TypeInner {
fn check_block(&self) -> Result<(), GlobalVariableError> {
match *self {
Self::Struct {
block: true,
members: _,
} => Ok(()),
Self::Struct {
block: false,
members: _,
} => Err(GlobalVariableError::MissingBlockDecoration),
_ => Err(GlobalVariableError::InvalidType),
}
}
}
fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse {
let mut storage_usage = crate::GlobalUse::empty();
if access.contains(crate::StorageAccess::LOAD) {
@@ -280,6 +275,7 @@ impl Validator {
pub fn new() -> Self {
Validator {
typifier: Typifier::new(),
type_flags: Vec::new(),
location_in_mask: BitSet::new(),
location_out_mask: BitSet::new(),
bind_group_masks: Vec::new(),
@@ -293,6 +289,29 @@ impl Validator {
}
}
fn fill_type_flags(&mut self, arena: &Arena<crate::Type>) {
for (handle, ty) in arena.iter().rev() {
let flags = self.type_flags[handle.index()];
match ty.inner {
crate::TypeInner::Array { base, .. } => {
//Note: don't assume anything about the indices,
// they are checked in `validate_type` later on.
if let Some(f) = self.type_flags.get_mut(base.index()) {
*f |= flags;
}
}
crate::TypeInner::Struct { ref members, .. } => {
for member in members {
if let Some(f) = self.type_flags.get_mut(member.ty.index()) {
*f |= flags;
}
}
}
_ => {}
}
}
}
fn validate_type(
&self,
ty: &crate::Type,
@@ -316,7 +335,7 @@ impl Validator {
return Err(TypeError::UnresolvedBase(base));
}
}
Ti::Array { base, size, .. } => {
Ti::Array { base, size, stride } => {
if base >= handle {
return Err(TypeError::UnresolvedBase(base));
}
@@ -335,12 +354,18 @@ impl Validator {
}
}
}
//TODO: check stride
if stride.is_none()
&& self.type_flags[handle.index()].contains(TypeFlags::HOST_SHARED)
{
return Err(TypeError::MissingStrideDecoration);
}
}
Ti::Struct {
block: _,
ref members,
} => {
//TODO: check the offsets
Ti::Struct { block, ref members } => {
if !block && self.type_flags[handle.index()].contains(TypeFlags::HOST_SHARED) {
return Err(TypeError::MissingBlockDecoration);
}
//TODO: check the spans
for member in members {
if member.ty >= handle {
return Err(TypeError::UnresolvedBase(member.ty));
@@ -396,42 +421,47 @@ impl Validator {
&self,
var: &crate::GlobalVariable,
types: &Arena<crate::Type>,
) -> Result<(), GlobalVariableError> {
) -> Result<TypeFlags, GlobalVariableError> {
log::debug!("var {:?}", var);
let allowed_storage_access = match var.class {
let (allowed_storage_access, type_flags) = match var.class {
crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage),
crate::StorageClass::Input | crate::StorageClass::Output => {
var.check_varying(types)?;
crate::StorageAccess::empty()
(crate::StorageAccess::empty(), TypeFlags::INTERFACE)
}
crate::StorageClass::Storage => {
var.check_resource()?;
let ty = &types[var.ty];
ty.inner.check_block()?;
crate::StorageAccess::all()
match types[var.ty].inner {
crate::TypeInner::Struct { .. } => (),
_ => return Err(GlobalVariableError::InvalidType),
}
(crate::StorageAccess::all(), TypeFlags::HOST_SHARED)
}
crate::StorageClass::Uniform => {
var.check_resource()?;
let ty = &types[var.ty];
ty.inner.check_block()?;
crate::StorageAccess::empty()
match types[var.ty].inner {
crate::TypeInner::Struct { .. } => (),
_ => return Err(GlobalVariableError::InvalidType),
}
(crate::StorageAccess::empty(), TypeFlags::HOST_SHARED)
}
crate::StorageClass::Handle => {
var.check_resource()?;
match types[var.ty].inner {
let allowed_access = match types[var.ty].inner {
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..
} => crate::StorageAccess::all(),
_ => crate::StorageAccess::empty(),
}
};
(allowed_access, TypeFlags::empty())
}
crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
if var.binding.is_some() {
return Err(GlobalVariableError::InvalidBinding);
}
var.forbid_interpolation()?;
crate::StorageAccess::empty()
(crate::StorageAccess::empty(), TypeFlags::empty())
}
crate::StorageClass::PushConstant => {
//TODO
@@ -449,7 +479,7 @@ impl Validator {
});
}
Ok(())
Ok(type_flags)
}
fn validate_local_var(
@@ -665,15 +695,9 @@ impl Validator {
/// Check the given module to be valid.
pub fn validate(&mut self, module: &crate::Module) -> Result<(), ValidationError> {
self.typifier.clear();
for (handle, ty) in module.types.iter() {
self.validate_type(ty, handle, &module.constants)
.map_err(|error| ValidationError::Type {
handle,
name: ty.name.clone().unwrap_or_default(),
error,
})?;
}
self.type_flags.clear();
self.type_flags
.resize(module.types.len(), TypeFlags::empty());
for (handle, constant) in module.constants.iter() {
self.validate_constant(handle, &module.constants, &module.types)
@@ -683,13 +707,28 @@ impl Validator {
error,
})?;
}
for (var_handle, var) in module.global_variables.iter() {
self.validate_global_var(var, &module.types)
let ty_flags = self
.validate_global_var(var, &module.types)
.map_err(|error| ValidationError::GlobalVariable {
handle: var_handle,
name: var.name.clone().unwrap_or_default(),
error,
})?;
self.type_flags[var.ty.index()] |= ty_flags;
}
self.fill_type_flags(&module.types);
// doing after the globals, so that `type_flags` is ready
for (handle, ty) in module.types.iter() {
self.validate_type(ty, handle, &module.constants)
.map_err(|error| ValidationError::Type {
handle,
name: ty.name.clone().unwrap_or_default(),
error,
})?;
}
for (fun_handle, fun) in module.functions.iter() {