diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index ce798bc884..280305bcd4 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -88,21 +88,9 @@ impl Layouter { }, Ti::Array { base, size, stride } => { let count = match size { - crate::ArraySize::Constant(handle) => match constants[handle].inner { - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Uint(value), - } => value as u32, - // Accept a signed integer size to avoid - // requiring an explicit uint - // literal. Type inference should make - // this unnecessary. - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Sint(value), - } => value as u32, - ref other => unreachable!("Unexpected array size {:?}", other), - }, + crate::ArraySize::Constant(handle) => { + constants[handle].to_array_length().unwrap() + } // A dynamically-sized array has to have at least one element crate::ArraySize::Dynamic => 1, }; diff --git a/src/proc/mod.rs b/src/proc/mod.rs index aa383fcb02..c4ad803af3 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -158,3 +158,22 @@ impl crate::SampleLevel { } } } + +impl crate::Constant { + pub fn to_array_length(&self) -> Option { + use std::convert::TryInto; + match self.inner { + crate::ConstantInner::Scalar { value, width: _ } => match value { + crate::ScalarValue::Uint(value) => value.try_into().ok(), + // Accept a signed integer size to avoid + // requiring an explicit uint + // literal. Type inference should make + // this unnecessary. + crate::ScalarValue::Sint(value) => value.try_into().ok(), + _ => None, + }, + // caught by type validation + crate::ConstantInner::Composite { .. } => None, + } + } +} diff --git a/src/proc/validator.rs b/src/proc/validator.rs index 4c5a804e4d..49466d83b0 100644 --- a/src/proc/validator.rs +++ b/src/proc/validator.rs @@ -173,6 +173,14 @@ pub enum ExpressionError { InvalidPointerType(Handle), #[error("Array length of {0:?} can't be done")] InvalidArrayType(Handle), + #[error("Compose type {0:?} doesn't exist")] + ComposeTypeDoesntExist(Handle), + #[error("Composing of type {0:?} can't be done")] + InvalidComposeType(Handle), + #[error("Composing expects {expected} components but {given} were given")] + InvalidComposeCount { given: u32, expected: u32 }, + #[error("Composing {0}'s component {1:?} is not expected")] + InvalidComponentType(u32, Handle), } #[derive(Clone, Debug, Error)] @@ -880,7 +888,6 @@ impl Validator { module: &crate::Module, ) -> Result<(), ExpressionError> { use crate::{Expression as E, TypeInner as Ti}; - use std::convert::TryInto; let resolver = ExpressionTypeResolver { root, @@ -923,24 +930,7 @@ impl Validator { Ti::Array { size: crate::ArraySize::Constant(handle), .. - } => { - match module - .constants - .try_get(handle) - .ok_or(ExpressionError::ConstantDoesntExist(handle))? - .inner - { - crate::ConstantInner::Scalar { value, width: _ } => match value { - crate::ScalarValue::Sint(value) => { - value.max(0).try_into().unwrap_or(!0) - } - crate::ScalarValue::Uint(value) => value.try_into().unwrap_or(!0), - _ => unreachable!(), - }, - // caught by type validation - crate::ConstantInner::Composite { .. } => unreachable!(), - } - } + } => module.constants[handle].to_array_length().unwrap(), Ti::Array { .. } => !0, // can't statically know, but need run-time checks Ti::Pointer { .. } => !0, //TODO Ti::Struct { @@ -963,7 +953,121 @@ impl Validator { .ok_or(ExpressionError::ConstantDoesntExist(handle))?; } E::Compose { ref components, ty } => { - //TODO + match module + .types + .try_get(ty) + .ok_or(ExpressionError::ComposeTypeDoesntExist(ty))? + .inner + { + // vectors are composed from scalars or other vectors + Ti::Vector { size, kind, width } => { + let inner = Ti::Scalar { kind, width }; + let mut total = 0; + for (index, &comp) in components.iter().enumerate() { + total += match *resolver.resolve(comp)? { + Ti::Scalar { + kind: comp_kind, + width: comp_width, + } if comp_kind == kind && comp_width == width => 1, + Ti::Vector { + size: comp_size, + kind: comp_kind, + width: comp_width, + } if comp_kind == kind && comp_width == width => comp_size as u32, + ref other => { + log::error!("Vector component[{}] type {:?}", index, other); + return Err(ExpressionError::InvalidComponentType( + index as u32, + comp, + )); + } + }; + } + if size as u32 != total { + return Err(ExpressionError::InvalidComposeCount { + expected: size as u32, + given: total, + }); + } + } + // matrix are composed from column vectors + Ti::Matrix { + columns, + rows, + width, + } => { + let inner = Ti::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }; + if columns as usize != components.len() { + return Err(ExpressionError::InvalidComposeCount { + expected: columns as u32, + given: components.len() as u32, + }); + } + for (index, &comp) in components.iter().enumerate() { + let tin = resolver.resolve(comp)?; + if tin != &inner { + log::error!("Matrix component[{}] type {:?}", index, tin); + return Err(ExpressionError::InvalidComponentType( + index as u32, + comp, + )); + } + } + } + Ti::Array { + base, + size: crate::ArraySize::Constant(handle), + stride: _, + } => { + let count = module.constants[handle].to_array_length().unwrap(); + if count as usize != components.len() { + return Err(ExpressionError::InvalidComposeCount { + expected: count, + given: components.len() as u32, + }); + } + let base_inner = &module.types[base].inner; + for (index, &comp) in components.iter().enumerate() { + let tin = resolver.resolve(comp)?; + if tin != base_inner { + log::error!("Array component[{}] type {:?}", index, tin); + return Err(ExpressionError::InvalidComponentType( + index as u32, + comp, + )); + } + } + } + Ti::Struct { + block: _, + ref members, + } => { + for (index, (member, &comp)) in members.iter().zip(components).enumerate() { + let tin = resolver.resolve(comp)?; + if tin != &module.types[member.ty].inner { + log::error!("Struct component[{}] type {:?}", index, tin); + return Err(ExpressionError::InvalidComponentType( + index as u32, + comp, + )); + } + } + if members.len() != components.len() { + return Err(ExpressionError::InvalidComposeCount { + given: components.len() as u32, + expected: members.len() as u32, + }); + } + } + ref other => { + log::error!("Composing of {:?}", other); + return Err(ExpressionError::InvalidComposeType(ty)); + } + } } E::FunctionArgument(index) => { if index >= function.arguments.len() as u32 {