diff --git a/src/valid/expression.rs b/src/valid/expression.rs index c2d18d4158..c996c9dbcb 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -77,6 +77,26 @@ pub enum ExpressionError { InvalidImageOtherIndexType(Handle), #[error("Image coordinate index type of {1:?} does not match dimension {0:?}")] InvalidImageCoordinateType(crate::ImageDimension, Handle), + #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")] + ComparisonSamplingMismatch { + image: crate::ImageClass, + sampler: bool, + has_ref: bool, + }, + #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")] + InvalidSampleOffset(crate::ImageDimension, Handle), + #[error("Depth reference {0:?} is not a scalar float")] + InvalidDepthReference(Handle), + #[error("Sample level is not compatible with the image dimension {0:?}")] + InvalidSampleLevel(crate::ImageDimension), + #[error("Sample level (exact) type {0:?} is not a scalar float")] + InvalidSampleLevelExactType(Handle), + #[error("Sample level (bias) type {0:?} is not a scalar float")] + InvalidSampleLevelBiasType(Handle), + #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")] + InvalidSampleLevelGradientType(crate::ImageDimension, Handle), + #[error("Unable to cast")] + InvalidCastArgument, } struct ExpressionTypeResolver<'a> { @@ -327,7 +347,6 @@ impl super::Validator { } ShaderStages::all() } - #[allow(unused)] E::ImageSample { image, sampler, @@ -336,7 +355,176 @@ impl super::Validator { offset, level, depth_ref, - } => ShaderStages::all(), + } => { + let image_var = match function.expressions[image] { + crate::Expression::GlobalVariable(var_handle) => { + &module.global_variables[var_handle] + } + _ => return Err(ExpressionError::ExpectedGlobalVariable), + }; + let sampler_var = match function.expressions[sampler] { + crate::Expression::GlobalVariable(var_handle) => { + &module.global_variables[var_handle] + } + _ => return Err(ExpressionError::ExpectedGlobalVariable), + }; + let comparison = match module.types[sampler_var.ty].inner { + Ti::Sampler { comparison } => comparison, + _ => return Err(ExpressionError::ExpectedSamplerType(sampler_var.ty)), + }; + + let (class, dim) = match module.types[image_var.ty].inner { + Ti::Image { + //TODO: should we check that this is Float-only? + class, + arrayed, + dim, + } => { + let image_depth = match class { + crate::ImageClass::Sampled { + kind: _, + multi: false, + } => false, + crate::ImageClass::Depth => true, + _ => return Err(ExpressionError::InvalidImageClass(class)), + }; + if comparison != depth_ref.is_some() || (comparison && !image_depth) { + return Err(ExpressionError::ComparisonSamplingMismatch { + image: class, + sampler: comparison, + has_ref: depth_ref.is_some(), + }); + } + if arrayed != array_index.is_some() { + return Err(ExpressionError::InvalidImageArrayIndex); + } + if let Some(expr) = array_index { + match *resolver.resolve(expr)? { + Ti::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => {} + _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)), + } + } + (class, dim) + } + _ => return Err(ExpressionError::ExpectedImageType(image_var.ty)), + }; + + let num_components = match dim { + crate::ImageDimension::D1 => 1, + crate::ImageDimension::D2 => 2, + crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3, + }; + match *resolver.resolve(coordinate)? { + Ti::Scalar { + kind: crate::ScalarKind::Float, + .. + } if num_components == 1 => {} + Ti::Vector { + size, + kind: crate::ScalarKind::Float, + .. + } if size as u32 == num_components => {} + _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)), + } + if let Some(const_handle) = offset { + let good = match module.constants[const_handle].inner { + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Sint(_), + } => num_components == 1, + crate::ConstantInner::Scalar { .. } => false, + crate::ConstantInner::Composite { ty, .. } => { + match module.types[ty].inner { + Ti::Vector { + size, + kind: crate::ScalarKind::Float, + .. + } => size as u32 == num_components, + _ => false, + } + } + }; + if !good { + return Err(ExpressionError::InvalidSampleOffset(dim, const_handle)); + } + } + + if let Some(expr) = depth_ref { + match *resolver.resolve(expr)? { + Ti::Scalar { + kind: crate::ScalarKind::Float, + .. + } => {} + _ => return Err(ExpressionError::InvalidDepthReference(expr)), + } + } + let can_level = match class { + crate::ImageClass::Sampled { multi, .. } => !multi, + crate::ImageClass::Storage { .. } => false, + crate::ImageClass::Depth { .. } => true, + }; + + match level { + // require `can_level` here? + crate::SampleLevel::Auto => ShaderStages::FRAGMENT, + crate::SampleLevel::Zero => ShaderStages::all(), + crate::SampleLevel::Exact(expr) if can_level => { + match *resolver.resolve(expr)? { + Ti::Scalar { + kind: crate::ScalarKind::Float, + .. + } => {} + _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)), + } + ShaderStages::all() + } + crate::SampleLevel::Bias(expr) => { + match *resolver.resolve(expr)? { + Ti::Scalar { + kind: crate::ScalarKind::Float, + .. + } => {} + _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)), + } + ShaderStages::all() + } + crate::SampleLevel::Gradient { x, y } => { + match *resolver.resolve(x)? { + Ti::Scalar { + kind: crate::ScalarKind::Float, + .. + } if num_components == 1 => {} + Ti::Vector { + size, + kind: crate::ScalarKind::Float, + .. + } if size as u32 == num_components => {} + _ => { + return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x)) + } + } + match *resolver.resolve(y)? { + Ti::Scalar { + kind: crate::ScalarKind::Float, + .. + } if num_components == 1 => {} + Ti::Vector { + size, + kind: crate::ScalarKind::Float, + .. + } if size as u32 == num_components => {} + _ => { + return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y)) + } + } + ShaderStages::all() + } + _ => return Err(ExpressionError::InvalidSampleLevel(dim)), + } + } E::ImageLoad { image, coordinate, @@ -683,12 +871,22 @@ impl super::Validator { arg1, arg2, } => ShaderStages::all(), - #[allow(unused)] E::As { expr, kind, convert, - } => ShaderStages::all(), + } => { + let prev_kind = resolver + .resolve(expr)? + .scalar_kind() + .ok_or(ExpressionError::InvalidCastArgument)?; + if !convert && prev_kind == crate::ScalarKind::Bool + || kind == crate::ScalarKind::Bool + { + return Err(ExpressionError::InvalidCastArgument); + } + ShaderStages::all() + } E::Call(function) => other_infos[function.index()].available_stages, E::ArrayLength(expr) => match *resolver.resolve(expr)? { Ti::Array { .. } => ShaderStages::all(),