diff --git a/src/proc/validator.rs b/src/proc/validator.rs index 618a79061e..c018505f2c 100644 --- a/src/proc/validator.rs +++ b/src/proc/validator.rs @@ -2,7 +2,10 @@ use super::{ analyzer::{Analysis, AnalysisError, FunctionInfo, GlobalUse}, typifier::{ResolveContext, Typifier, TypifyError}, }; -use crate::arena::{Arena, Handle}; +use crate::{ + arena::{Arena, Handle}, + FastHashSet, +}; use bit_set::BitSet; use thiserror::Error; @@ -22,6 +25,45 @@ bitflags::bitflags! { } } +bitflags::bitflags! { + #[repr(transparent)] + pub struct BlockFlags: u8 { + /// The control can jump out of this block. + const CAN_JUMP = 0x1; + /// The control is in a loop, can break and continue. + const IN_LOOP = 0x2; + } +} + +struct BlockContext<'a> { + flags: BlockFlags, + expressions: &'a Arena, + types: &'a Arena, + functions: &'a Arena, + return_type: Option>, +} + +impl<'a> BlockContext<'a> { + fn with_flags(&self, flags: BlockFlags) -> Self { + BlockContext { + flags, + expressions: self.expressions, + types: self.types, + functions: self.functions, + return_type: self.return_type, + } + } + + fn get_expression( + &self, + handle: Handle, + ) -> Result<&'a crate::Expression, FunctionError> { + self.expressions + .try_get(handle) + .ok_or(FunctionError::InvalidExpression(handle)) + } +} + #[derive(Debug)] pub struct Validator { //Note: this is a bit tricky: some of the front-ends as well as backends @@ -31,6 +73,7 @@ pub struct Validator { location_in_mask: BitSet, location_out_mask: BitSet, bind_group_masks: Vec, + select_cases: FastHashSet, } #[derive(Clone, Debug, Error)] @@ -95,8 +138,6 @@ pub enum LocalVariableError { pub enum FunctionError { #[error(transparent)] Resolve(#[from] TypifyError), - #[error("There are instructions after `return`/`break`/`continue`")] - InvalidControlFlowExitTail, #[error("Local variable {handle:?} '{name}' is invalid")] LocalVariable { handle: Handle, @@ -106,6 +147,50 @@ pub enum FunctionError { }, #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")] InvalidArgumentType { index: usize, name: String }, + #[error("There are instructions after `return`/`break`/`continue`")] + InstructionsAfterReturn, + #[error("The `break`/`continue` is used outside of a loop context")] + BreakContinueOutsideOfLoop, + #[error("The `return` is called within a `continuing` block")] + InvalidReturnSpot, + #[error("The `return` value {0:?} does not match the function return value")] + InvalidReturnType(Option>), + #[error("The `if` condition {0:?} is not a boolean scalar")] + InvalidIfType(Handle), + #[error("The `switch` value {0:?} is not an integer scalar")] + InvalidSwitchType(Handle), + #[error("Multiple `switch` cases for {0} are present")] + ConflictingSwitchCase(i32), + #[error("The value {0:?} can not be stored")] + InvalidStoreValue(Handle), + #[error("Store of {value:?} into {pointer:?} doesn't have matching types")] + InvalidStore { + pointer: Handle, + value: Handle, + }, + #[error("The image array can't be indexed by {0:?}")] + InvalidArrayIndex(Handle), + #[error("The expression {0:?} is currupted")] + InvalidExpression(Handle), + #[error("The expression {0:?} is not an image")] + InvalidImage(Handle), + #[error("The called function {0:?} is invalid")] + InvalidCall(Handle), + #[error( + "The called function {function:?} requires {required} arguments, but {seen} are provided" + )] + InvalidCallArgumentCount { + function: Handle, + required: usize, + seen: usize, + }, + #[error("The called function {function:?} argument {index} requires {required:?} type, but {seen:?} type is provided")] + InvalidCallArgumentType { + function: Handle, + index: usize, + required: Handle, + seen: Option>, + }, } #[derive(Clone, Debug, Error)] @@ -324,6 +409,7 @@ impl Validator { location_in_mask: BitSet::new(), location_out_mask: BitSet::new(), bind_group_masks: Vec::new(), + select_cases: FastHashSet::default(), } } @@ -572,6 +658,165 @@ impl Validator { Ok(()) } + fn validate_block( + &mut self, + statements: &[crate::Statement], + context: &BlockContext, + ) -> Result<(), FunctionError> { + use crate::{Statement as S, TypeInner as Ti}; + //TODO: handle the cases of totally invalid expression handles + let mut finished = false; + for statement in statements { + if finished { + return Err(FunctionError::InstructionsAfterReturn); + } + match *statement { + S::Block(ref block) => self.validate_block(block, context)?, + S::If { + condition, + ref accept, + ref reject, + } => { + match *self.typifier.get(condition, context.types) { + Ti::Scalar { + kind: crate::ScalarKind::Bool, + width: _, + } => {} + _ => return Err(FunctionError::InvalidIfType(condition)), + } + self.validate_block(accept, context)?; + self.validate_block(reject, context)?; + } + S::Switch { + selector, + ref cases, + ref default, + } => { + match *self.typifier.get(selector, context.types) { + Ti::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => {} + _ => return Err(FunctionError::InvalidSwitchType(selector)), + } + self.select_cases.clear(); + for case in cases { + if !self.select_cases.insert(case.value) { + return Err(FunctionError::ConflictingSwitchCase(case.value)); + } + } + for case in cases { + self.validate_block(&case.body, context)?; + } + self.validate_block(default, context)?; + } + S::Loop { + ref body, + ref continuing, + } => { + self.validate_block( + body, + &context.with_flags(BlockFlags::CAN_JUMP | BlockFlags::IN_LOOP), + )?; + self.validate_block(continuing, &context.with_flags(BlockFlags::empty()))?; + } + S::Break | S::Continue => { + if !context.flags.contains(BlockFlags::IN_LOOP) { + return Err(FunctionError::BreakContinueOutsideOfLoop); + } + finished = true; + } + S::Return { value } => { + if !context.flags.contains(BlockFlags::CAN_JUMP) { + return Err(FunctionError::InvalidReturnSpot); + } + let value_ty = value.map(|expr| self.typifier.get(expr, context.types)); + let expected_ty = context.return_type.map(|ty| &context.types[ty].inner); + if value_ty != expected_ty { + log::error!( + "Returning {:?} where {:?} is expected", + value_ty, + expected_ty + ); + return Err(FunctionError::InvalidReturnType(value)); + } + finished = true; + } + S::Kill => { + finished = true; + } + S::Store { pointer, value } => { + let value_ty = self.typifier.get(value, context.types); + match *value_ty { + Ti::Image { .. } | Ti::Sampler { .. } => { + return Err(FunctionError::InvalidStoreValue(value)); + } + _ => {} + } + if self.typifier.get(pointer, context.types) != value_ty { + return Err(FunctionError::InvalidStore { pointer, value }); + } + //TODO: validate that the `pointer` reaches a variable through a sequence of accessors + } + S::ImageStore { + image, + coordinate: _, + array_index, + value, + } => { + let _expected_coordinate_ty = match *context.get_expression(image)? { + crate::Expression::GlobalVariable(_var_handle) => (), //TODO + _ => return Err(FunctionError::InvalidImage(image)), + }; + let value_ty = self.typifier.get(value, context.types); + match *value_ty { + Ti::Scalar { .. } | Ti::Vector { .. } => {} + _ => { + return Err(FunctionError::InvalidStoreValue(value)); + } + } + if let Some(expr) = array_index { + match *self.typifier.get(expr, context.types) { + Ti::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => (), + _ => return Err(FunctionError::InvalidArrayIndex(expr)), + } + } + } + S::Call { + function, + ref arguments, + } => { + let fun = context + .functions + .try_get(function) + .ok_or(FunctionError::InvalidCall(function))?; + if fun.arguments.len() != arguments.len() { + return Err(FunctionError::InvalidCallArgumentCount { + function, + required: fun.arguments.len(), + seen: arguments.len(), + }); + } + for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() { + let ty = self.typifier.get_handle(expr); + if ty != Ok(arg.ty) { + return Err(FunctionError::InvalidCallArgumentType { + function, + index, + required: arg.ty, + seen: ty.ok(), + }); + } + } + } + } + } + Ok(()) + } + fn validate_function( &mut self, fun: &crate::Function, @@ -598,18 +843,24 @@ impl Validator { } for (index, argument) in fun.arguments.iter().enumerate() { - match module.types[argument.ty].inner { - crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => { - return Err(FunctionError::InvalidArgumentType { - index, - name: argument.name.clone().unwrap_or_default(), - }) - } - _ => (), + if !self.type_flags[argument.ty.index()].contains(TypeFlags::DATA) { + return Err(FunctionError::InvalidArgumentType { + index, + name: argument.name.clone().unwrap_or_default(), + }); } } - Ok(()) + self.validate_block( + &fun.body, + &BlockContext { + flags: BlockFlags::CAN_JUMP, + expressions: &fun.expressions, + types: &module.types, + functions: &module.functions, + return_type: fun.return_type, + }, + ) } fn validate_entry_point(