diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index d6cedea182..e76b86c745 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -850,21 +850,9 @@ impl> Parser { let index_expr_data = &expressions[index_expr.handle]; let index_maybe = match *index_expr_data { crate::Expression::Constant(const_handle) => { - match const_arena[const_handle].inner { - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Uint(v), - } => Some(v as u32), - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Sint(v), - } => Some(v as u32), - _ => { - return Err(Error::InvalidAccess( - crate::Expression::Constant(const_handle), - )) - } - } + Some(const_arena[const_handle].to_array_length().ok_or( + Error::InvalidAccess(crate::Expression::Constant(const_handle)), + )?) } _ => None, }; diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index ce62a5b218..1af4cf8b59 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -62,6 +62,18 @@ impl Clone for TypeResolution { } } +impl crate::ConstantInner { + pub fn resolve_type(&self) -> TypeResolution { + match *self { + Self::Scalar { width, ref value } => TypeResolution::Value(crate::TypeInner::Scalar { + kind: value.scalar_kind(), + width, + }), + Self::Composite { ty, components: _ } => TypeResolution::Handle(ty), + } + } +} + #[derive(Clone, Debug, Error, PartialEq)] pub enum ResolveError { #[error("Index {index} is out of bounds for expression {expr:?}")] diff --git a/src/valid/compose.rs b/src/valid/compose.rs new file mode 100644 index 0000000000..e25a7379d0 --- /dev/null +++ b/src/valid/compose.rs @@ -0,0 +1,131 @@ +use crate::{ + arena::{Arena, Handle}, + proc::TypeResolution, +}; + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ComposeError { + #[error("Compose type {0:?} doesn't exist")] + TypeDoesntExist(Handle), + #[error("Composing of type {0:?} can't be done")] + Type(Handle), + #[error("Composing expects {expected} components but {given} were given")] + ComponentCount { given: u32, expected: u32 }, + #[error("Composing {index}'s component type is not expected")] + ComponentType { index: u32 }, +} + +pub fn validate_compose( + self_ty_handle: Handle, + constant_arena: &Arena, + type_arena: &Arena, + component_resolutions: impl ExactSizeIterator, +) -> Result<(), ComposeError> { + use crate::TypeInner as Ti; + + let self_ty = type_arena + .try_get(self_ty_handle) + .ok_or(ComposeError::TypeDoesntExist(self_ty_handle))?; + match self_ty.inner { + // vectors are composed from scalars or other vectors + Ti::Vector { size, kind, width } => { + let mut total = 0; + for (index, comp_res) in component_resolutions.enumerate() { + total += match *comp_res.inner_with(type_arena) { + Ti::Scalar { + kind: comp_kind, + width: comp_width, + } if comp_kind == kind && comp_width == width => 1, + Ti::Vector { + size: comp_size, + kind: comp_kind, + width: comp_width, + } if comp_kind == kind && comp_width == width => comp_size as u32, + ref other => { + log::error!("Vector component[{}] type {:?}", index, other); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + }; + } + if size as u32 != total { + return Err(ComposeError::ComponentCount { + expected: size as u32, + given: total, + }); + } + } + // matrix are composed from column vectors + Ti::Matrix { + columns, + rows, + width, + } => { + let inner = Ti::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }; + if columns as usize != component_resolutions.len() { + return Err(ComposeError::ComponentCount { + expected: columns as u32, + given: component_resolutions.len() as u32, + }); + } + for (index, comp_res) in component_resolutions.enumerate() { + if comp_res.inner_with(type_arena) != &inner { + log::error!("Matrix component[{}] type {:?}", index, comp_res); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + } + } + Ti::Array { + base, + size: crate::ArraySize::Constant(handle), + stride: _, + } => { + let count = constant_arena[handle].to_array_length().unwrap(); + if count as usize != component_resolutions.len() { + return Err(ComposeError::ComponentCount { + expected: count, + given: component_resolutions.len() as u32, + }); + } + for (index, comp_res) in component_resolutions.enumerate() { + if comp_res.inner_with(type_arena) != &type_arena[base].inner { + log::error!("Array component[{}] type {:?}", index, comp_res); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + } + } + Ti::Struct { ref members, .. } => { + if members.len() != component_resolutions.len() { + return Err(ComposeError::ComponentCount { + given: component_resolutions.len() as u32, + expected: members.len() as u32, + }); + } + for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate() + { + if comp_res.inner_with(type_arena) != &type_arena[member.ty].inner { + log::error!("Struct component[{}] type {:?}", index, comp_res); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + } + } + ref other => { + log::error!("Composing of {:?}", other); + return Err(ComposeError::Type(self_ty_handle)); + } + } + + Ok(()) +} diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 9a1872d130..1ec603ce71 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1,4 +1,4 @@ -use super::{FunctionInfo, ShaderStages, TypeFlags}; +use super::{compose::validate_compose, ComposeError, FunctionInfo, ShaderStages, TypeFlags}; use crate::{ arena::{Arena, Handle}, proc::ResolveError, @@ -33,14 +33,8 @@ pub enum ExpressionError { InvalidArrayType(Handle), #[error("Splatting {0:?} can't be done")] InvalidSplatType(Handle), - #[error("Compose type {0:?} doesn't exist")] - ComposeTypeDoesntExist(Handle), - #[error("Composing of type {0:?} can't be done")] - InvalidComposeType(Handle), - #[error("Composing expects {expected} components but {given} were given")] - InvalidComposeCount { given: u32, expected: u32 }, - #[error("Composing {0}'s component {1:?} is not expected")] - InvalidComponentType(u32, Handle), + #[error(transparent)] + Compose(#[from] ComposeError), #[error("Operation {0:?} can't work with {1:?}")] InvalidUnaryOperandType(crate::UnaryOperator, Handle), #[error("Operation {0:?} can't work with {1:?} and {2:?}")] @@ -207,117 +201,17 @@ impl super::Validator { } }, E::Compose { ref components, ty } => { - match module - .types - .try_get(ty) - .ok_or(ExpressionError::ComposeTypeDoesntExist(ty))? - .inner - { - // vectors are composed from scalars or other vectors - Ti::Vector { size, kind, width } => { - let mut total = 0; - for (index, &comp) in components.iter().enumerate() { - total += match *resolver.resolve(comp)? { - Ti::Scalar { - kind: comp_kind, - width: comp_width, - } if comp_kind == kind && comp_width == width => 1, - Ti::Vector { - size: comp_size, - kind: comp_kind, - width: comp_width, - } if comp_kind == kind && comp_width == width => comp_size as u32, - ref other => { - log::error!("Vector component[{}] type {:?}", index, other); - return Err(ExpressionError::InvalidComponentType( - index as u32, - comp, - )); - } - }; - } - if size as u32 != total { - return Err(ExpressionError::InvalidComposeCount { - expected: size as u32, - given: total, - }); - } - } - // matrix are composed from column vectors - Ti::Matrix { - columns, - rows, - width, - } => { - let inner = Ti::Vector { - size: rows, - kind: Sk::Float, - width, - }; - if columns as usize != components.len() { - return Err(ExpressionError::InvalidComposeCount { - expected: columns as u32, - given: components.len() as u32, - }); - } - for (index, &comp) in components.iter().enumerate() { - let tin = resolver.resolve(comp)?; - if tin != &inner { - log::error!("Matrix component[{}] type {:?}", index, tin); - return Err(ExpressionError::InvalidComponentType( - index as u32, - comp, - )); - } - } - } - Ti::Array { - base, - size: crate::ArraySize::Constant(handle), - stride: _, - } => { - let count = module.constants[handle].to_array_length().unwrap(); - if count as usize != components.len() { - return Err(ExpressionError::InvalidComposeCount { - expected: count, - given: components.len() as u32, - }); - } - let base_inner = &module.types[base].inner; - for (index, &comp) in components.iter().enumerate() { - let tin = resolver.resolve(comp)?; - if tin != base_inner { - log::error!("Array component[{}] type {:?}", index, tin); - return Err(ExpressionError::InvalidComponentType( - index as u32, - comp, - )); - } - } - } - Ti::Struct { ref members, .. } => { - for (index, (member, &comp)) in members.iter().zip(components).enumerate() { - let tin = resolver.resolve(comp)?; - if tin != &module.types[member.ty].inner { - log::error!("Struct component[{}] type {:?}", index, tin); - return Err(ExpressionError::InvalidComponentType( - index as u32, - comp, - )); - } - } - if members.len() != components.len() { - return Err(ExpressionError::InvalidComposeCount { - given: components.len() as u32, - expected: members.len() as u32, - }); - } - } - ref other => { - log::error!("Composing of {:?}", other); - return Err(ExpressionError::InvalidComposeType(ty)); + for &handle in components { + if handle >= root { + return Err(ExpressionError::ForwardDependency(handle)); } } + validate_compose( + ty, + &module.constants, + &module.types, + components.iter().map(|&handle| info[handle].ty.clone()), + )?; ShaderStages::all() } E::FunctionArgument(index) => { diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 87bfd20885..9ef3bb3121 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -1,4 +1,5 @@ mod analyzer; +mod compose; mod expression; mod function; mod interface; @@ -15,6 +16,7 @@ use std::ops; // merge the corresponding matches over expressions and statements. pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; +pub use compose::ComposeError; pub use expression::ExpressionError; pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; @@ -33,6 +35,8 @@ bitflags::bitflags! { const CONTROL_FLOW_UNIFORMITY = 0x4; /// Host-shareable structure layouts. const STRUCT_LAYOUTS = 0x8; + /// Constants. + const CONSTANTS = 0x10; } } @@ -81,6 +85,8 @@ pub enum ConstantError { UnresolvedComponent(Handle), #[error("The array size handle {0:?} can not be resolved")] UnresolvedSize(Handle), + #[error(transparent)] + Compose(#[from] ComposeError), } #[derive(Clone, Debug, thiserror::Error)] @@ -191,25 +197,25 @@ impl Validator { } crate::ConstantInner::Composite { ty, ref components } => { match types[ty].inner { - crate::TypeInner::Array { - size: crate::ArraySize::Dynamic, - .. - } => { - return Err(ConstantError::InvalidType); - } crate::TypeInner::Array { size: crate::ArraySize::Constant(size_handle), .. - } => { - if handle <= size_handle { - return Err(ConstantError::UnresolvedSize(size_handle)); - } + } if handle <= size_handle => { + return Err(ConstantError::UnresolvedSize(size_handle)); } - _ => {} //TODO + _ => {} } if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) { return Err(ConstantError::UnresolvedComponent(comp)); } + compose::validate_compose( + ty, + constants, + types, + components + .iter() + .map(|&component| constants[component].inner.resolve_type()), + )?; } } Ok(()) @@ -219,13 +225,15 @@ impl Validator { pub fn validate(&mut self, module: &crate::Module) -> Result { self.reset_types(module.types.len()); - for (handle, constant) in module.constants.iter() { - self.validate_constant(handle, &module.constants, &module.types) - .map_err(|error| ValidationError::Constant { - handle, - name: constant.name.clone().unwrap_or_default(), - error, - })?; + if self.flags.contains(ValidationFlags::CONSTANTS) { + for (handle, constant) in module.constants.iter() { + self.validate_constant(handle, &module.constants, &module.types) + .map_err(|error| ValidationError::Constant { + handle, + name: constant.name.clone().unwrap_or_default(), + error, + })?; + } } // doing after the globals, so that `type_flags` is ready diff --git a/tests/out/collatz.info.ron b/tests/out/collatz.info.ron index 6bea486301..edf6a5a5a4 100644 --- a/tests/out/collatz.info.ron +++ b/tests/out/collatz.info.ron @@ -2,7 +2,7 @@ functions: [ ( flags: ( - bits: 15, + bits: 31, ), available_stages: ( bits: 7, @@ -341,7 +341,7 @@ entry_points: [ ( flags: ( - bits: 15, + bits: 31, ), available_stages: ( bits: 7, diff --git a/tests/out/shadow.info.ron b/tests/out/shadow.info.ron index 6e11adbc9b..d1491f6c8c 100644 --- a/tests/out/shadow.info.ron +++ b/tests/out/shadow.info.ron @@ -2,7 +2,7 @@ functions: [ ( flags: ( - bits: 15, + bits: 31, ), available_stages: ( bits: 7, @@ -1063,7 +1063,7 @@ ), ( flags: ( - bits: 15, + bits: 31, ), available_stages: ( bits: 7, @@ -2753,7 +2753,7 @@ entry_points: [ ( flags: ( - bits: 15, + bits: 31, ), available_stages: ( bits: 7,