From 1d3f2bbdb1c27b167a584908fb562e7742650d65 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Mon, 22 Mar 2021 00:45:11 -0400 Subject: [PATCH] Merge FunctionAnalysisError into FunctionError --- src/back/msl/writer.rs | 4 +- src/valid/analyzer.rs | 99 +++++++++-------------------------------- src/valid/expression.rs | 1 + src/valid/function.rs | 28 +++++++++--- src/valid/interface.rs | 11 ++--- src/valid/mod.rs | 56 ++++++++++++++--------- 6 files changed, 88 insertions(+), 111 deletions(-) diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 5bb2b4369b..f9c8e7eb1d 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1516,7 +1516,9 @@ fn test_stack_size() { }); let _ = module.functions.append(fun); // analyse the module - let info = ModuleInfo::new(&module, ValidationFlags::empty()).unwrap(); + let info = crate::valid::Validator::new(ValidationFlags::empty()) + .validate(&module) + .unwrap(); // process the module let mut writer = Writer::new(String::new()); writer.write(&module, &info, &Default::default()).unwrap(); diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 520042bba9..20da97ad8d 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -6,7 +6,7 @@ Figures out the following properties: - expression reference counts !*/ -use super::ValidationFlags; +use super::{CallError, FunctionError, ModuleInfo, ValidationFlags}; use crate::arena::{Arena, Handle}; use std::ops; @@ -216,32 +216,6 @@ pub enum UniformityDisruptor { Discard, } -#[derive(Clone, Debug, thiserror::Error)] -#[cfg_attr(test, derive(PartialEq))] -pub enum FunctionAnalysisError { - #[error("Expression {0:?} is not a global variable!")] - ExpectedGlobalVariable(crate::Expression), - #[error("Called function {0:?} that hasn't been declared in the IR yet")] - ForwardCall(Handle), - #[error( - "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" - )] - NonUniformControlFlow( - UniformityRequirements, - Handle, - UniformityDisruptor, - ), -} - -#[derive(Clone, Debug, thiserror::Error)] -#[cfg_attr(test, derive(PartialEq))] -pub enum AnalysisError { - #[error("Function {0:?} analysis failed")] - Function(Handle, #[source] FunctionAnalysisError), - #[error("Entry point {0:?}/'{1}' function analysis failed")] - EntryPoint(crate::ShaderStage, String, #[source] FunctionAnalysisError), -} - impl FunctionInfo { /// Adds a value-type reference to an expression. #[must_use] @@ -314,7 +288,7 @@ impl FunctionInfo { arguments: &[crate::FunctionArgument], global_var_arena: &Arena, other_functions: &[FunctionInfo], - ) -> Result<(), FunctionAnalysisError> { + ) -> Result<(), FunctionError> { use crate::{Expression as E, SampleLevel as Sl}; let mut assignable_global = None; @@ -404,17 +378,13 @@ impl FunctionInfo { image: match expression_arena[image] { crate::Expression::GlobalVariable(var) => var, ref other => { - return Err(FunctionAnalysisError::ExpectedGlobalVariable( - other.clone(), - )) + return Err(FunctionError::ExpectedGlobalVariable(other.clone())) } }, sampler: match expression_arena[sampler] { crate::Expression::GlobalVariable(var) => var, ref other => { - return Err(FunctionAnalysisError::ExpectedGlobalVariable( - other.clone(), - )) + return Err(FunctionError::ExpectedGlobalVariable(other.clone())) } }, }); @@ -512,9 +482,13 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, E::Call(function) => { - let fun = other_functions - .get(function.index()) - .ok_or(FunctionAnalysisError::ForwardCall(function))?; + let fun = + other_functions + .get(function.index()) + .ok_or(FunctionError::InvalidCall { + function, + error: CallError::ForwardDeclaredFunction, + })?; self.process_call(fun).result } E::ArrayLength(expr) => Uniformity { @@ -546,7 +520,7 @@ impl FunctionInfo { statements: &[crate::Statement], other_functions: &[FunctionInfo], mut disruptor: Option, - ) -> Result { + ) -> Result { use crate::Statement as S; let mut combined_uniformity = FunctionUniformity::new(); @@ -562,9 +536,7 @@ impl FunctionInfo { && !req.is_empty() { if let Some(cause) = disruptor { - return Err(FunctionAnalysisError::NonUniformControlFlow( - req, expr, cause, - )); + return Err(FunctionError::NonUniformControlFlow(req, expr, cause)); } } requirements |= req; @@ -670,9 +642,12 @@ impl FunctionInfo { for &argument in arguments { let _ = self.add_ref(argument); } - let info = other_functions - .get(function.index()) - .ok_or(FunctionAnalysisError::ForwardCall(function))?; + let info = other_functions.get(function.index()).ok_or( + FunctionError::InvalidCall { + function, + error: CallError::ForwardDeclaredFunction, + }, + )?; self.process_call(info) } }; @@ -684,22 +659,15 @@ impl FunctionInfo { } } -#[cfg_attr(feature = "serialize", derive(serde::Serialize))] -#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] -pub struct ModuleInfo { - functions: Vec, - entry_points: Vec, -} - impl ModuleInfo { /// Builds the `FunctionInfo` based on the function, and validates the /// uniform control flow if required by the expressions of this function. - fn process_function( + pub(super) fn process_function( &self, fun: &crate::Function, global_var_arena: &Arena, flags: ValidationFlags, - ) -> Result { + ) -> Result { let mut info = FunctionInfo { flags, uniformity: Uniformity::new(), @@ -726,29 +694,6 @@ impl ModuleInfo { Ok(info) } - /// Analyze a module and return the `ModuleInfo`, if successful. - pub fn new(module: &crate::Module, flags: ValidationFlags) -> Result { - let mut this = ModuleInfo { - functions: Vec::with_capacity(module.functions.len()), - entry_points: Vec::with_capacity(module.entry_points.len()), - }; - for (fun_handle, fun) in module.functions.iter() { - let info = this - .process_function(fun, &module.global_variables, flags) - .map_err(|source| AnalysisError::Function(fun_handle, source))?; - this.functions.push(info); - } - - for ep in module.entry_points.iter() { - let info = this - .process_function(&ep.function, &module.global_variables, flags) - .map_err(|source| AnalysisError::EntryPoint(ep.stage, ep.name.clone(), source))?; - this.entry_points.push(info); - } - - Ok(this) - } - pub fn get_entry_point(&self, index: usize) -> &FunctionInfo { &self.entry_points[index] } @@ -880,7 +825,7 @@ fn uniform_control_flow() { }; assert_eq!( info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None), - Err(FunctionAnalysisError::NonUniformControlFlow( + Err(FunctionError::NonUniformControlFlow( UniformityRequirements::DERIVATIVE, derivative_expr, UniformityDisruptor::Expression(non_uniform_global_expr) diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 119d1f4d82..0e460f7930 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -4,6 +4,7 @@ use crate::{ }; #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ExpressionError { #[error("Doesn't exist")] DoesntExist, diff --git a/src/valid/function.rs b/src/valid/function.rs index 09ca6b44e8..770f70c851 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -1,10 +1,14 @@ -use super::{analyzer::FunctionInfo, ExpressionError, TypeFlags, ValidationFlags}; +use super::{ + analyzer::{FunctionInfo, UniformityDisruptor, UniformityRequirements}, + ExpressionError, ModuleInfo, TypeFlags, ValidationFlags, +}; use crate::{ arena::{Arena, Handle}, proc::{ResolveContext, TypifyError}, }; #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum CallError { #[error("Bad function")] InvalidFunction, @@ -36,12 +40,14 @@ pub enum CallError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum LocalVariableError { #[error("Initializer doesn't match the variable type")] InitializerType, } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum FunctionError { #[error(transparent)] Resolve(#[from] TypifyError), @@ -97,6 +103,16 @@ pub enum FunctionError { #[source] error: CallError, }, + #[error("Expression {0:?} is not a global variable!")] + ExpectedGlobalVariable(crate::Expression), + #[error( + "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" + )] + NonUniformControlFlow( + UniformityRequirements, + Handle, + UniformityDisruptor, + ), } bitflags::bitflags! { @@ -464,9 +480,9 @@ impl super::Validator { pub(super) fn validate_function( &mut self, fun: &crate::Function, - _info: &FunctionInfo, module: &crate::Module, - ) -> Result<(), FunctionError> { + mod_info: &ModuleInfo, + ) -> Result { let resolve_ctx = ResolveContext { constants: &module.constants, global_vars: &module.global_variables, @@ -476,6 +492,7 @@ impl super::Validator { }; self.typifier .resolve_all(&fun.expressions, &module.types, &resolve_ctx)?; + let info = mod_info.process_function(fun, &module.global_variables, self.flags)?; for (var_handle, var) in fun.local_variables.iter() { self.validate_local_var(var, &module.types, &module.constants) @@ -511,9 +528,8 @@ impl super::Validator { } if self.flags.contains(ValidationFlags::BLOCKS) { - self.validate_block(&fun.body, &BlockContext::new(fun, module)) - } else { - Ok(()) + self.validate_block(&fun.body, &BlockContext::new(fun, module))?; } + Ok(info) } } diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 2e5b275200..b6fe1a5a4b 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -1,6 +1,6 @@ use super::{ analyzer::{FunctionInfo, GlobalUse}, - Disalignment, FunctionError, TypeFlags, + Disalignment, FunctionError, ModuleInfo, TypeFlags, }; use crate::arena::{Arena, Handle}; @@ -352,9 +352,9 @@ impl super::Validator { pub(super) fn validate_entry_point( &mut self, ep: &crate::EntryPoint, - info: &FunctionInfo, module: &crate::Module, - ) -> Result<(), EntryPointError> { + mod_info: &ModuleInfo, + ) -> Result { if ep.early_depth_test.is_some() && ep.stage != crate::ShaderStage::Fragment { return Err(EntryPointError::UnexpectedEarlyDepthTest); } @@ -370,6 +370,8 @@ impl super::Validator { return Err(EntryPointError::UnexpectedWorkgroupSize); } + let info = self.validate_function(&ep.function, module, &mod_info)?; + self.location_mask.clear(); for (index, fa) in ep.function.arguments.iter().enumerate() { let ctx = VaryingContext { @@ -439,7 +441,6 @@ impl super::Validator { } } - self.validate_function(&ep.function, info, module)?; - Ok(()) + Ok(info) } } diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 3d5092e76f..b3223adf8e 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -13,10 +13,7 @@ use bit_set::BitSet; //TODO: analyze the model at the same time as we validate it, // merge the corresponding matches over expressions and statements. -pub use analyzer::{ - AnalysisError, ExpressionInfo, FunctionInfo, GlobalUse, ModuleInfo, Uniformity, - UniformityRequirements, -}; +pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; pub use expression::ExpressionError; pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; @@ -33,6 +30,13 @@ bitflags::bitflags! { } } +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct ModuleInfo { + functions: Vec, + entry_points: Vec, +} + #[derive(Debug)] pub struct Validator { flags: ValidationFlags, @@ -94,8 +98,6 @@ pub enum ValidationError { #[source] error: EntryPointError, }, - #[error(transparent)] - Analysis(#[from] AnalysisError), #[error("Module is corrupted")] Corrupted, } @@ -176,8 +178,10 @@ impl Validator { pub fn validate(&mut self, module: &crate::Module) -> Result { self.reset_types(module.types.len()); - let mod_info = ModuleInfo::new(module, self.flags)?; - + let mut mod_info = ModuleInfo { + functions: Vec::with_capacity(module.functions.len()), + entry_points: Vec::with_capacity(module.entry_points.len()), + }; let layouter = Layouter::new(&module.types, &module.constants); for (handle, constant) in module.constants.iter() { @@ -211,16 +215,20 @@ impl Validator { } for (handle, fun) in module.functions.iter() { - self.validate_function(fun, &mod_info[handle], module) - .map_err(|error| ValidationError::Function { - handle, - name: fun.name.clone().unwrap_or_default(), - error, - })?; + match self.validate_function(fun, module, &mod_info) { + Ok(info) => mod_info.functions.push(info), + Err(error) => { + return Err(ValidationError::Function { + handle, + name: fun.name.clone().unwrap_or_default(), + error, + }) + } + } } let mut ep_map = FastHashSet::default(); - for (index, ep) in module.entry_points.iter().enumerate() { + for ep in module.entry_points.iter() { if !ep_map.insert((ep.stage, &ep.name)) { return Err(ValidationError::EntryPoint { stage: ep.stage, @@ -228,13 +236,17 @@ impl Validator { error: EntryPointError::Conflict, }); } - let info = mod_info.get_entry_point(index); - self.validate_entry_point(ep, info, module) - .map_err(|error| ValidationError::EntryPoint { - stage: ep.stage, - name: ep.name.clone(), - error, - })?; + + match self.validate_entry_point(ep, module, &mod_info) { + Ok(info) => mod_info.entry_points.push(info), + Err(error) => { + return Err(ValidationError::EntryPoint { + stage: ep.stage, + name: ep.name.clone(), + error, + }) + } + } } Ok(mod_info)