diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index aae5dcddef..032c4f37f2 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -159,6 +159,24 @@ impl ExpressionInfo { } } +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +enum GlobalOrArgument { + Global(Handle), + Argument(u32), +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +struct ArgumentSampling { + /// Whether this argument is used as an image or a sampler + image: bool, + /// The other global or argument used in the sampling + uses: Vec, +} + #[derive(Debug)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] @@ -181,6 +199,10 @@ pub struct FunctionInfo { /// /// Each item corresponds to an expression in the function. expressions: Box<[ExpressionInfo]>, + /// Vector with information of wether or not a functin argument is used for sampling + /// + /// Each item corresponds to a function argument + argument_sampling: Box<[Option]>, } impl FunctionInfo { @@ -273,21 +295,101 @@ impl FunctionInfo { } /// Inherit information from a called function. - fn process_call(&mut self, info: &Self) -> FunctionUniformity { + fn process_call( + &mut self, + info: &Self, + arguments: &[Handle], + expression_arena: &Arena, + ) -> Result { for key in info.sampling_set.iter() { self.sampling_set.insert(key.clone()); } + for (i, arg_sampling) in info + .argument_sampling + .iter() + .enumerate() + .filter_map(|(i, s)| Some((i, s.as_ref()?))) + { + let handle = arguments[i]; + let arg_storage = match expression_arena[handle] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i), + _ => { + return Err(FunctionError::Expression { + handle, + error: ExpressionError::ExpectedGlobalVariable, + }) + } + }; + + for other in arg_sampling.uses.iter() { + let other_storage = match *other { + GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var), + GlobalOrArgument::Argument(i) => { + let other_handle = arguments[i as usize]; + match expression_arena[other_handle] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i), + _ => { + return Err(FunctionError::Expression { + handle, + error: ExpressionError::ExpectedGlobalVariable, + }) + } + } + } + }; + + match (arg_storage, other_storage) { + (GlobalOrArgument::Global(arg), GlobalOrArgument::Global(other)) => { + if arg_sampling.image { + self.sampling_set.insert(SamplingKey { + image: arg, + sampler: other, + }); + } else { + self.sampling_set.insert(SamplingKey { + image: other, + sampler: arg, + }); + } + } + (GlobalOrArgument::Argument(i), _) => { + let sampling = + self.argument_sampling[i as usize].get_or_insert_with(|| { + ArgumentSampling { + image: arg_sampling.image, + uses: Vec::with_capacity(1), + } + }); + + sampling.uses.push(other_storage) + } + (_, GlobalOrArgument::Argument(i)) => { + let sampling = + self.argument_sampling[i as usize].get_or_insert_with(|| { + ArgumentSampling { + image: !arg_sampling.image, + uses: Vec::with_capacity(1), + } + }); + + sampling.uses.push(arg_storage) + } + } + } + } for (mine, other) in self.global_uses.iter_mut().zip(info.global_uses.iter()) { *mine |= *other; } - FunctionUniformity { + Ok(FunctionUniformity { result: info.uniformity.clone(), exit: if info.may_kill { ExitFlags::MAY_KILL } else { ExitFlags::empty() }, - } + }) } /// Computes the expression info and stores it in `self.expressions`. @@ -397,16 +499,45 @@ impl FunctionInfo { level, depth_ref, } => { - self.sampling_set.insert(SamplingKey { - image: match expression_arena[image] { - crate::Expression::GlobalVariable(var) => var, - _ => return Err(ExpressionError::ExpectedGlobalVariable), - }, - sampler: match expression_arena[sampler] { - crate::Expression::GlobalVariable(var) => var, - _ => return Err(ExpressionError::ExpectedGlobalVariable), - }, - }); + let image_storage = match expression_arena[image] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i), + _ => return Err(ExpressionError::ExpectedGlobalVariable), + }; + let sampler_storage = match expression_arena[sampler] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i), + _ => return Err(ExpressionError::ExpectedGlobalVariable), + }; + + match (image_storage, sampler_storage) { + (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => { + self.sampling_set.insert(SamplingKey { image, sampler }); + } + (GlobalOrArgument::Argument(i), _) => { + let sampling = + self.argument_sampling[i as usize].get_or_insert_with(|| { + ArgumentSampling { + image: true, + uses: Vec::with_capacity(1), + } + }); + + sampling.uses.push(sampler_storage) + } + (_, GlobalOrArgument::Argument(i)) => { + let sampling = + self.argument_sampling[i as usize].get_or_insert_with(|| { + ArgumentSampling { + image: false, + uses: Vec::with_capacity(1), + } + }); + + sampling.uses.push(image_storage) + } + } + // "nur" == "Non-Uniform Result" let array_nur = array_index.and_then(|h| self.add_ref(h)); let level_nur = match level { @@ -501,10 +632,11 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, E::Call(function) => { - let fun = other_functions + let info = other_functions .get(function.index()) .ok_or(ExpressionError::CallToUndeclaredFunction(function))?; - self.process_call(fun).result + + info.uniformity.clone() } E::ArrayLength(expr) => Uniformity { non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY), @@ -537,6 +669,7 @@ impl FunctionInfo { statements: &[crate::Statement], other_functions: &[FunctionInfo], mut disruptor: Option, + expression_arena: &Arena, ) -> Result { use crate::Statement as S; @@ -582,7 +715,9 @@ impl FunctionInfo { }, exit: ExitFlags::empty(), }, - S::Block(ref b) => self.process_block(b, other_functions, disruptor)?, + S::Block(ref b) => { + self.process_block(b, other_functions, disruptor, expression_arena)? + } S::If { condition, ref accept, @@ -591,10 +726,18 @@ impl FunctionInfo { let condition_nur = self.add_ref(condition); let branch_disruptor = disruptor.or(condition_nur.map(UniformityDisruptor::Expression)); - let accept_uniformity = - self.process_block(accept, other_functions, branch_disruptor)?; - let reject_uniformity = - self.process_block(reject, other_functions, branch_disruptor)?; + let accept_uniformity = self.process_block( + accept, + other_functions, + branch_disruptor, + expression_arena, + )?; + let reject_uniformity = self.process_block( + reject, + other_functions, + branch_disruptor, + expression_arena, + )?; accept_uniformity | reject_uniformity } S::Switch { @@ -608,8 +751,12 @@ impl FunctionInfo { let mut uniformity = FunctionUniformity::new(); let mut case_disruptor = branch_disruptor; for case in cases.iter() { - let case_uniformity = - self.process_block(&case.body, other_functions, case_disruptor)?; + let case_uniformity = self.process_block( + &case.body, + other_functions, + case_disruptor, + expression_arena, + )?; case_disruptor = if case.fall_through { case_disruptor.or(case_uniformity.exit_disruptor()) } else { @@ -618,18 +765,27 @@ impl FunctionInfo { uniformity = uniformity | case_uniformity; } // using the disruptor inherited from the last fall-through chain - let default_exit = - self.process_block(default, other_functions, case_disruptor)?; + let default_exit = self.process_block( + default, + other_functions, + case_disruptor, + expression_arena, + )?; uniformity | default_exit } S::Loop { ref body, ref continuing, } => { - let body_uniformity = self.process_block(body, other_functions, disruptor)?; + let body_uniformity = + self.process_block(body, other_functions, disruptor, expression_arena)?; let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor()); - let continuing_uniformity = - self.process_block(continuing, other_functions, continuing_disruptor)?; + let continuing_uniformity = self.process_block( + continuing, + other_functions, + continuing_disruptor, + expression_arena, + )?; body_uniformity | continuing_uniformity } S::Return { value } => FunctionUniformity { @@ -680,7 +836,7 @@ impl FunctionInfo { }, )?; //Note: the result is validated by the Validator, not here - self.process_call(info) + self.process_call(info, arguments, expression_arena)? } }; @@ -708,6 +864,7 @@ impl ModuleInfo { sampling_set: crate::FastHashSet::default(), global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(), expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(), + argument_sampling: vec![None; fun.arguments.len()].into_boxed_slice(), }; let resolve_context = ResolveContext { constants: &module.constants, @@ -730,7 +887,7 @@ impl ModuleInfo { } } - let uniformity = info.process_block(&fun.body, &self.functions, None)?; + let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?; info.uniformity = uniformity.result; info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL); @@ -812,6 +969,7 @@ fn uniform_control_flow() { sampling_set: crate::FastHashSet::default(), global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(), expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(), + argument_sampling: vec![None; 0].into_boxed_slice(), }; let resolve_context = ResolveContext { constants: &constant_arena, @@ -845,7 +1003,7 @@ fn uniform_control_flow() { ], }; assert_eq!( - info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None), + info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None, &expressions), Ok(FunctionUniformity { result: Uniformity { non_uniform_result: None, @@ -870,7 +1028,7 @@ fn uniform_control_flow() { reject: Vec::new(), }; assert_eq!( - info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None), + info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None, &expressions), Err(FunctionError::NonUniformControlFlow( UniformityRequirements::DERIVATIVE, derivative_expr, @@ -888,7 +1046,8 @@ fn uniform_control_flow() { info.process_block( &[stmt_emit3, stmt_return_non_uniform], &[], - Some(UniformityDisruptor::Return) + Some(UniformityDisruptor::Return), + &expressions ), Ok(FunctionUniformity { result: Uniformity { @@ -914,7 +1073,8 @@ fn uniform_control_flow() { info.process_block( &[stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer], &[], - Some(UniformityDisruptor::Discard) + Some(UniformityDisruptor::Discard), + &expressions ), Ok(FunctionUniformity { result: Uniformity { diff --git a/src/valid/expression.rs b/src/valid/expression.rs index a4035fede9..409ca82eb0 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -339,24 +339,26 @@ impl super::Validator { depth_ref, } => { // check the validity of expressions - let image_var = match function.expressions[image] { + let image_ty = match function.expressions[image] { crate::Expression::GlobalVariable(var_handle) => { - &module.global_variables[var_handle] + module.global_variables[var_handle].ty } + crate::Expression::FunctionArgument(i) => function.arguments[i as usize].ty, _ => return Err(ExpressionError::ExpectedGlobalVariable), }; - let sampler_var = match function.expressions[sampler] { + let sampler_ty = match function.expressions[sampler] { crate::Expression::GlobalVariable(var_handle) => { - &module.global_variables[var_handle] + module.global_variables[var_handle].ty } + crate::Expression::FunctionArgument(i) => function.arguments[i as usize].ty, _ => return Err(ExpressionError::ExpectedGlobalVariable), }; - let comparison = match module.types[sampler_var.ty].inner { + let comparison = match module.types[sampler_ty].inner { Ti::Sampler { comparison } => comparison, - _ => return Err(ExpressionError::ExpectedSamplerType(sampler_var.ty)), + _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)), }; - let (class, dim) = match module.types[image_var.ty].inner { + let (class, dim) = match module.types[image_ty].inner { Ti::Image { class, arrayed, @@ -377,7 +379,7 @@ impl super::Validator { } (class, dim) } - _ => return Err(ExpressionError::ExpectedImageType(image_var.ty)), + _ => return Err(ExpressionError::ExpectedImageType(image_ty)), }; // check sampling and comparison properties @@ -515,13 +517,14 @@ impl super::Validator { array_index, index, } => { - let var = match function.expressions[image] { + let ty = match function.expressions[image] { crate::Expression::GlobalVariable(var_handle) => { - &module.global_variables[var_handle] + module.global_variables[var_handle].ty } + crate::Expression::FunctionArgument(i) => function.arguments[i as usize].ty, _ => return Err(ExpressionError::ExpectedGlobalVariable), }; - match module.types[var.ty].inner { + match module.types[ty].inner { Ti::Image { class, arrayed, @@ -564,18 +567,19 @@ impl super::Validator { } } } - _ => return Err(ExpressionError::ExpectedImageType(var.ty)), + _ => return Err(ExpressionError::ExpectedImageType(ty)), } ShaderStages::all() } E::ImageQuery { image, query } => { - let var = match function.expressions[image] { + let ty = match function.expressions[image] { crate::Expression::GlobalVariable(var_handle) => { - &module.global_variables[var_handle] + module.global_variables[var_handle].ty } + crate::Expression::FunctionArgument(i) => function.arguments[i as usize].ty, _ => return Err(ExpressionError::ExpectedGlobalVariable), }; - match module.types[var.ty].inner { + match module.types[ty].inner { Ti::Image { class, arrayed, .. } => { let can_level = match class { crate::ImageClass::Sampled { multi, .. } => !multi, @@ -593,7 +597,7 @@ impl super::Validator { return Err(ExpressionError::InvalidImageClass(class)); } } - _ => return Err(ExpressionError::ExpectedImageType(var.ty)), + _ => return Err(ExpressionError::ExpectedImageType(ty)), } ShaderStages::all() } diff --git a/src/valid/function.rs b/src/valid/function.rs index bb5d49c977..b05c0e3917 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -578,7 +578,7 @@ impl super::Validator { for (index, argument) in fun.arguments.iter().enumerate() { if !self.types[argument.ty.index()] .flags - .contains(TypeFlags::DATA | TypeFlags::SIZED) + .contains(TypeFlags::ARGUMENT) { return Err(FunctionError::InvalidArgumentType { index, diff --git a/src/valid/type.rs b/src/valid/type.rs index 474c5dc0af..e1f25af7bd 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -46,6 +46,9 @@ bitflags::bitflags! { /// This is a top-level host-shareable type. const TOP_LEVEL = 0x10; + + /// This type can be passed as a function argument. + const ARGUMENT = 0x20; } } @@ -192,7 +195,8 @@ impl super::Validator { TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::INTERFACE - | TypeFlags::HOST_SHARED, + | TypeFlags::HOST_SHARED + | TypeFlags::ARGUMENT, width as u32, ) } @@ -205,7 +209,8 @@ impl super::Validator { TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::INTERFACE - | TypeFlags::HOST_SHARED, + | TypeFlags::HOST_SHARED + | TypeFlags::ARGUMENT, count * (width as u32), ) } @@ -222,7 +227,8 @@ impl super::Validator { TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::INTERFACE - | TypeFlags::HOST_SHARED, + | TypeFlags::HOST_SHARED + | TypeFlags::ARGUMENT, count * (width as u32), ) } @@ -239,9 +245,9 @@ impl super::Validator { // `DATA`. let base_info = &self.types[base.index()]; let data_flag = if base_info.flags.contains(TypeFlags::SIZED) { - TypeFlags::DATA + TypeFlags::DATA | TypeFlags::ARGUMENT } else if let crate::TypeInner::Struct { .. } = types[base].inner { - TypeFlags::DATA + TypeFlags::DATA | TypeFlags::ARGUMENT } else { TypeFlags::empty() }; @@ -350,7 +356,7 @@ impl super::Validator { return Err(TypeError::NonPositiveArrayLength(const_handle)); } - TypeFlags::SIZED + TypeFlags::SIZED | TypeFlags::ARGUMENT } crate::ArraySize::Dynamic => { // Non-SIZED types may only appear as the last element of a structure. @@ -376,7 +382,8 @@ impl super::Validator { TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHARED - | TypeFlags::INTERFACE, + | TypeFlags::INTERFACE + | TypeFlags::ARGUMENT, 1, ); let mut min_offset = 0; @@ -460,7 +467,7 @@ impl super::Validator { ti } - Ti::Image { .. } | Ti::Sampler { .. } => TypeInfo::new(TypeFlags::empty(), 0), + Ti::Image { .. } | Ti::Sampler { .. } => TypeInfo::new(TypeFlags::ARGUMENT, 0), }) } } diff --git a/tests/out/analysis/collatz.info.ron b/tests/out/analysis/collatz.info.ron index edf6a5a5a4..e46776e9de 100644 --- a/tests/out/analysis/collatz.info.ron +++ b/tests/out/analysis/collatz.info.ron @@ -336,6 +336,9 @@ ty: Handle(1), ), ], + argument_sampling: [ + None, + ], ), ], entry_points: [ @@ -492,6 +495,9 @@ ty: Handle(1), ), ], + argument_sampling: [ + None, + ], ), ], ) \ No newline at end of file diff --git a/tests/out/analysis/shadow.info.ron b/tests/out/analysis/shadow.info.ron index 6d59c22e68..abaef43360 100644 --- a/tests/out/analysis/shadow.info.ron +++ b/tests/out/analysis/shadow.info.ron @@ -1060,6 +1060,10 @@ )), ), ], + argument_sampling: [ + None, + None, + ], ), ( flags: ( @@ -2748,6 +2752,7 @@ ty: Handle(4), ), ], + argument_sampling: [], ), ], entry_points: [ @@ -2871,6 +2876,10 @@ ty: Handle(4), ), ], + argument_sampling: [ + None, + None, + ], ), ], ) \ No newline at end of file