From bde43cbec6e502d60cc3cf55b9e83d9fd33c9349 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 11 Feb 2021 01:24:40 -0500 Subject: [PATCH] First uniform control flow tests --- src/lib.rs | 1 + src/proc/analyzer.rs | 104 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index abad73cb4e..d569af58ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -641,6 +641,7 @@ pub enum ImageQuery { /// An expression that can be evaluated to obtain a value. #[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq))] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum Expression { diff --git a/src/proc/analyzer.rs b/src/proc/analyzer.rs index fcf968030f..13d9d36918 100644 --- a/src/proc/analyzer.rs +++ b/src/proc/analyzer.rs @@ -39,6 +39,7 @@ pub struct FunctionInfo { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum AnalysisError { #[error("Expression {0:?} is not a global variable!")] ExpectedGlobalVariable(crate::Expression), @@ -348,3 +349,106 @@ impl Analysis { Ok(this) } } + +#[test] +fn uniform_control_flow() { + use crate::{Expression as E, Statement as S}; + + let mut constant_arena = Arena::new(); + let constant = constant_arena.append(crate::Constant { + name: None, + specialization: None, + inner: crate::ConstantInner::Scalar { + width: 4, + value: crate::ScalarValue::Uint(0), + }, + }); + let mut type_arena = Arena::new(); + let ty = type_arena.append(crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width: 4, + }, + }); + let mut global_var_arena = Arena::new(); + let non_uniform_global = global_var_arena.append(crate::GlobalVariable { + name: None, + init: None, + ty, + binding: Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex)), + class: crate::StorageClass::Input, + interpolation: None, + storage_access: crate::StorageAccess::empty(), + }); + let uniform_global = global_var_arena.append(crate::GlobalVariable { + name: None, + init: None, + ty, + binding: Some(crate::Binding::Location(0)), + class: crate::StorageClass::Input, + interpolation: Some(crate::Interpolation::Flat), + storage_access: crate::StorageAccess::empty(), + }); + + let mut expressions = Arena::new(); + let constant_expr = expressions.append(E::Constant(constant)); + let derivative_expr = expressions.append(E::Derivative { + axis: crate::DerivativeAxis::X, + expr: constant_expr, + }); + let non_uniform_global_expr = expressions.append(E::GlobalVariable(non_uniform_global)); + let uniform_global_expr = expressions.append(E::GlobalVariable(uniform_global)); + + let mut info = FunctionInfo { + control_flags: ControlFlags::empty(), + sampling_set: crate::FastHashSet::default(), + expressions: vec![ExpressionInfo::default(); expressions.len()].into_boxed_slice(), + }; + for (handle, _) in expressions.iter() { + info.process_expression(handle, &expressions, &global_var_arena, &[]) + .unwrap(); + } + assert_eq!(info.expressions[non_uniform_global.index()].ref_count, 1); + assert_eq!(info.expressions[uniform_global_expr.index()].ref_count, 0); + + let stmt_if_uniform = S::If { + condition: uniform_global_expr, + accept: Vec::new(), + reject: vec![S::Store { + pointer: constant_expr, + value: derivative_expr, + }], + }; + assert_eq!( + info.process_block(&[stmt_if_uniform], &[], true), + Ok(ControlFlags::REQUIRE_UNIFORM), + ); + assert_eq!(info.expressions[constant_expr.index()].ref_count, 2); + + let stmt_if_non_uniform = S::If { + condition: non_uniform_global_expr, + accept: vec![S::Store { + pointer: constant_expr, + value: derivative_expr, + }], + reject: Vec::new(), + }; + assert_eq!( + info.process_block(&[stmt_if_non_uniform], &[], true), + Err(AnalysisError::NonUniformControlFlow), + ); + assert_eq!(info.expressions[derivative_expr.index()].ref_count, 2); + + let stmt_return_non_uniform = S::Return { + value: Some(non_uniform_global_expr), + }; + assert_eq!( + info.process_block(&[stmt_return_non_uniform], &[], false), + Ok(ControlFlags::NON_UNIFORM_RESULT | ControlFlags::MAY_EXIT), + ); + assert_eq!( + info.expressions[non_uniform_global_expr.index()].ref_count, + 2 + ); +}