diff --git a/src/valid/expression.rs b/src/valid/expression.rs index c996c9dbcb..7128a0c230 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -75,7 +75,7 @@ pub enum ExpressionError { InvalidImageArrayIndexType(Handle), #[error("Image other index type of {0:?} is not an integer scalar")] InvalidImageOtherIndexType(Handle), - #[error("Image coordinate index type of {1:?} does not match dimension {0:?}")] + #[error("Image coordinate type of {1:?} does not match dimension {0:?}")] InvalidImageCoordinateType(crate::ImageDimension, Handle), #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")] ComparisonSamplingMismatch { @@ -543,34 +543,14 @@ impl super::Validator { arrayed, dim, } => { - match (dim, resolver.resolve(coordinate)?) { - ( - crate::ImageDimension::D1, - &Ti::Scalar { - kind: crate::ScalarKind::Sint, - .. - }, - ) - | ( - crate::ImageDimension::D2, - &Ti::Vector { - kind: crate::ScalarKind::Sint, - .. - }, - ) - | ( - crate::ImageDimension::D3, - &Ti::Vector { - kind: crate::ScalarKind::Sint, - .. - }, - ) => {} + match resolver.resolve(coordinate)?.image_storage_coordinates() { + Some(coord_dim) if coord_dim == dim => {} _ => { return Err(ExpressionError::InvalidImageCoordinateType( dim, coordinate, )) } - } + }; let needs_index = match class { crate::ImageClass::Storage { .. } => false, _ => true, diff --git a/src/valid/function.rs b/src/valid/function.rs index 50653a78d9..081b2decd9 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -83,12 +83,10 @@ pub enum FunctionError { pointer: Handle, value: Handle, }, - #[error("The image array can't be indexed by {0:?}")] - InvalidArrayIndex(Handle), #[error("The expression {0:?} is currupted")] InvalidExpression(Handle), - #[error("The expression {0:?} is not an image")] - InvalidImage(Handle), + #[error("Image store parameters are invalid")] + InvalidImageStore(#[source] ExpressionError), #[error("Call to {function:?} is invalid")] InvalidCall { function: Handle, @@ -120,6 +118,7 @@ struct BlockContext<'a> { info: &'a FunctionInfo, expressions: &'a Arena, types: &'a Arena, + global_vars: &'a Arena, functions: &'a Arena, return_type: Option>, } @@ -131,6 +130,7 @@ impl<'a> BlockContext<'a> { info, expressions: &fun.expressions, types: &module.types, + global_vars: &module.global_variables, functions: &module.functions, return_type: fun.result.as_ref().map(|fr| fr.ty), } @@ -142,6 +142,7 @@ impl<'a> BlockContext<'a> { info: self.info, expressions: self.expressions, types: self.types, + global_vars: self.global_vars, functions: self.functions, return_type: self.return_type, } @@ -329,7 +330,7 @@ impl super::Validator { let mut current = pointer; loop { let _ = context.resolve_type(current)?; - match context.expressions[current] { + match *context.get_expression(current)? { crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => current = base, crate::Expression::LocalVariable(_) @@ -368,28 +369,82 @@ impl super::Validator { } S::ImageStore { image, - coordinate: _, + coordinate, array_index, value, } => { - let _expected_coordinate_ty = match *context.get_expression(image)? { - crate::Expression::GlobalVariable(_var_handle) => (), //TODO - _ => return Err(FunctionError::InvalidImage(image)), - }; - match *context.resolve_type(value)? { - Ti::Scalar { .. } | Ti::Vector { .. } => {} + //Note: this code uses a lot of `FunctionError::InvalidImageStore`, + // and could probably be refactored. + let var = match *context.get_expression(image)? { + crate::Expression::GlobalVariable(var_handle) => { + &context.global_vars[var_handle] + } _ => { - return Err(FunctionError::InvalidStoreValue(value)); + return Err(FunctionError::InvalidImageStore( + ExpressionError::ExpectedGlobalVariable, + )) } - } - if let Some(expr) = array_index { - match *context.resolve_type(expr)? { - Ti::Scalar { - kind: crate::ScalarKind::Sint, - width: _, - } => (), - _ => return Err(FunctionError::InvalidArrayIndex(expr)), + }; + + let value_ty = match context.types[var.ty].inner { + Ti::Image { + class, + arrayed, + dim, + } => { + match context + .resolve_type(coordinate)? + .image_storage_coordinates() + { + Some(coord_dim) if coord_dim == dim => {} + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageCoordinateType( + dim, coordinate, + ), + )) + } + }; + if arrayed != array_index.is_some() { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageArrayIndex, + )); + } + if let Some(expr) = array_index { + match *context.resolve_type(expr)? { + Ti::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => {} + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageArrayIndexType(expr), + )) + } + } + } + match class { + crate::ImageClass::Storage(format) => crate::TypeInner::Vector { + kind: format.into(), + size: crate::VectorSize::Quad, + width: 4, + }, + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageClass(class), + )) + } + } } + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::ExpectedImageType(var.ty), + )) + } + }; + + if *context.resolve_type(value)? != value_ty { + return Err(FunctionError::InvalidStoreValue(value)); } } S::Call { diff --git a/src/valid/mod.rs b/src/valid/mod.rs index aff2714756..0d316e1f68 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -136,6 +136,26 @@ impl crate::TypeInner { Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } => false, } } + + fn image_storage_coordinates(&self) -> Option { + match *self { + Self::Scalar { + kind: crate::ScalarKind::Sint, + .. + } => Some(crate::ImageDimension::D1), + Self::Vector { + size: crate::VectorSize::Bi, + kind: crate::ScalarKind::Sint, + .. + } => Some(crate::ImageDimension::D2), + Self::Vector { + size: crate::VectorSize::Tri, + kind: crate::ScalarKind::Sint, + .. + } => Some(crate::ImageDimension::D3), + _ => None, + } + } } impl Validator {