diff --git a/bin/convert.rs b/bin/convert.rs index a5870b8486..14aa6cdf7c 100644 --- a/bin/convert.rs +++ b/bin/convert.rs @@ -172,15 +172,14 @@ fn main() { // validate the IR #[allow(unused_variables)] - let analysis = match naga::proc::Validator::new(naga::proc::analyzer::AnalysisFlags::all()) - .validate(&module) - { - Ok(analysis) => Some(analysis), - Err(error) => { - print_err(error); - None - } - }; + let info = + match naga::valid::Validator::new(naga::valid::AnalysisFlags::all()).validate(&module) { + Ok(info) => Some(info), + Err(error) => { + print_err(error); + None + } + }; let output_path = match output_path { Some(ref string) => string, @@ -200,15 +199,14 @@ fn main() { "metal" => { use naga::back::msl; let (msl, _) = - msl::write_string(&module, analysis.as_ref().unwrap(), ¶ms.msl).unwrap_pretty(); + msl::write_string(&module, info.as_ref().unwrap(), ¶ms.msl).unwrap_pretty(); fs::write(output_path, msl).unwrap(); } #[cfg(feature = "spv-out")] "spv" => { use naga::back::spv; - let spv = - spv::write_vec(&module, analysis.as_ref().unwrap(), ¶ms.spv).unwrap_pretty(); + let spv = spv::write_vec(&module, info.as_ref().unwrap(), ¶ms.spv).unwrap_pretty(); let bytes = spv .iter() .fold(Vec::with_capacity(spv.len() * 4), |mut v, w| { @@ -236,9 +234,8 @@ fn main() { .open(output_path) .unwrap(); - let mut writer = - glsl::Writer::new(file, &module, analysis.as_ref().unwrap(), ¶ms.glsl) - .unwrap_pretty(); + let mut writer = glsl::Writer::new(file, &module, info.as_ref().unwrap(), ¶ms.glsl) + .unwrap_pretty(); writer .write() @@ -251,7 +248,7 @@ fn main() { #[cfg(feature = "dot-out")] "dot" => { use naga::back::dot; - let output = dot::write(&module, analysis.as_ref()).unwrap(); + let output = dot::write(&module, info.as_ref()).unwrap(); fs::write(output_path, output).unwrap(); } other => { diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 8996058407..8296e0ba2a 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -6,7 +6,7 @@ use crate::{ arena::Handle, - proc::analyzer::{Analysis, FunctionInfo}, + valid::{FunctionInfo, ModuleInfo}, }; use std::{ @@ -429,7 +429,7 @@ fn write_fun( Ok(()) } -pub fn write(module: &crate::Module, analysis: Option<&Analysis>) -> Result { +pub fn write(module: &crate::Module, mod_info: Option<&ModuleInfo>) -> Result { use std::fmt::Write as _; let mut output = String::new(); @@ -458,7 +458,7 @@ pub fn write(module: &crate::Module, analysis: Option<&Analysis>) -> Result) -> Result { /// The module being written module: &'a Module, /// The module analysis. - analysis: &'a Analysis, + info: &'a ModuleInfo, /// The output writer out: W, /// User defined configuration to be used @@ -348,7 +346,7 @@ impl<'a, W: Write> Writer<'a, W> { pub fn new( out: W, module: &'a Module, - analysis: &'a Analysis, + info: &'a ModuleInfo, options: &'a Options, ) -> Result { // Check if the requested version is supported @@ -371,7 +369,7 @@ impl<'a, W: Write> Writer<'a, W> { // Build the instance let mut this = Self { module, - analysis, + info, out, options, @@ -473,7 +471,7 @@ impl<'a, W: Write> Writer<'a, W> { } } - let ep_info = self.analysis.get_entry_point(self.entry_point_idx as usize); + let ep_info = self.info.get_entry_point(self.entry_point_idx as usize); // Write the globals // @@ -548,13 +546,13 @@ impl<'a, W: Write> Writer<'a, W> { for (handle, function) in self.module.functions.iter() { // Check that the function doesn't use globals that aren't supported // by the current entry point - if !ep_info.dominates_global_use(&self.analysis[handle]) { + if !ep_info.dominates_global_use(&self.info[handle]) { continue; } // We also `clone` to satisfy the borrow checker let name = self.names[&NameKey::Function(handle)].clone(); - let fun_info = &self.analysis[handle]; + let fun_info = &self.info[handle]; // Write the function self.write_function(FunctionType::Function(handle), function, fun_info, &name)?; @@ -2071,7 +2069,7 @@ impl<'a, W: Write> Writer<'a, W> { /// [`Arena`](crate::arena::Arena) and we need to traverse it fn collect_reflection_info(&self) -> Result { use std::collections::hash_map::Entry; - let info = self.analysis.get_entry_point(self.entry_point_idx as usize); + let info = self.info.get_entry_point(self.entry_point_idx as usize); let mut mappings = FastHashMap::default(); let mut uniforms = FastHashMap::default(); diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 9eda260cf9..b651a25d76 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -23,11 +23,7 @@ For the result type, if it's a structure, we re-compose it with a temporary valu holding the result. !*/ -use crate::{ - arena::Handle, - proc::{analyzer::Analysis, TypifyError}, - FastHashMap, -}; +use crate::{arena::Handle, proc::TypifyError, valid::ModuleInfo, FastHashMap}; use std::{ fmt::{Error as FmtError, Write}, string::FromUtf8Error, @@ -243,11 +239,11 @@ pub struct TranslationInfo { pub fn write_string( module: &crate::Module, - analysis: &Analysis, + info: &ModuleInfo, options: &Options, ) -> Result<(String, TranslationInfo), Error> { let mut w = writer::Writer::new(String::new()); - let info = w.write(module, analysis, options)?; + let info = w.write(module, info, options)?; Ok((w.finish(), info)) } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 2a94423525..56f4a695b5 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1,10 +1,8 @@ use super::{keywords::RESERVED, Error, LocationMode, Options, TranslationInfo}; use crate::{ arena::Handle, - proc::{ - analyzer::{Analysis, FunctionInfo, GlobalUse}, - EntryPointIndex, NameKey, Namer, ResolveContext, Typifier, - }, + proc::{EntryPointIndex, NameKey, Namer, ResolveContext, Typifier}, + valid::{FunctionInfo, GlobalUse, ModuleInfo}, FastHashMap, }; use bit_set::BitSet; @@ -141,7 +139,7 @@ struct ExpressionContext<'a> { function: &'a crate::Function, origin: FunctionOrigin, module: &'a crate::Module, - analysis: &'a Analysis, + mod_info: &'a ModuleInfo, } struct StatementContext<'a> { @@ -859,7 +857,7 @@ impl Writer { } // follow-up with any global resources used let mut separate = !arguments.is_empty(); - let fun_info = &context.expression.analysis[function]; + let fun_info = &context.expression.mod_info[function]; for (handle, var) in context.expression.module.global_variables.iter() { if !fun_info[handle].is_empty() && var.class.needs_pass_through() { let name = &self.names[&NameKey::GlobalVariable(handle)]; @@ -882,7 +880,7 @@ impl Writer { pub fn write( &mut self, module: &crate::Module, - analysis: &Analysis, + info: &ModuleInfo, options: &Options, ) -> Result { self.names.clear(); @@ -894,7 +892,7 @@ impl Writer { self.write_type_defs(module)?; self.write_constants(module)?; - self.write_functions(module, analysis, options) + self.write_functions(module, info, options) } fn write_type_defs(&mut self, module: &crate::Module) -> Result<(), Error> { @@ -1106,7 +1104,7 @@ impl Writer { fn write_functions( &mut self, module: &crate::Module, - analysis: &Analysis, + mod_info: &ModuleInfo, options: &Options, ) -> Result { let mut pass_through_globals = Vec::new(); @@ -1123,7 +1121,7 @@ impl Writer { }, )?; - let fun_info = &analysis[fun_handle]; + let fun_info = &mod_info[fun_handle]; pass_through_globals.clear(); for (handle, var) in module.global_variables.iter() { if !fun_info[handle].is_empty() && var.class.needs_pass_through() { @@ -1179,7 +1177,7 @@ impl Writer { function: fun, origin: FunctionOrigin::Handle(fun_handle), module, - analysis, + mod_info, }, fun_info, result_struct: None, @@ -1195,7 +1193,7 @@ impl Writer { }; for (ep_index, ep) in module.entry_points.iter().enumerate() { let fun = &ep.function; - let fun_info = analysis.get_entry_point(ep_index); + let fun_info = mod_info.get_entry_point(ep_index); // skip this entry point if any global bindings are missing if !options.fake_missing_bindings { if let Some(err) = module @@ -1471,7 +1469,7 @@ impl Writer { function: fun, origin: FunctionOrigin::EntryPoint(ep_index as _), module, - analysis, + mod_info, }, fun_info, result_struct: Some(&stage_out_name), @@ -1490,7 +1488,7 @@ impl Writer { #[test] fn test_stack_size() { - use crate::proc::analyzer::AnalysisFlags; + use crate::valid::AnalysisFlags; // create a module with at least one expression nested let mut module = crate::Module::default(); let constant = module.constants.append(crate::Constant { @@ -1518,12 +1516,10 @@ fn test_stack_size() { }); let _ = module.functions.append(fun); // analyse the module - let analysis = Analysis::new(&module, AnalysisFlags::empty()).unwrap(); + let info = ModuleInfo::new(&module, AnalysisFlags::empty()).unwrap(); // process the module let mut writer = Writer::new(String::new()); - writer - .write(&module, &analysis, &Default::default()) - .unwrap(); + writer.write(&module, &info, &Default::default()).unwrap(); let (mut min_addr, mut max_addr) = (!0usize, 0usize); for pointer in writer.put_expression_stack_pointers { min_addr = min_addr.min(pointer as usize); diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 4a780f2efd..4251b5cb41 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -73,11 +73,11 @@ impl Default for Options { pub fn write_vec( module: &crate::Module, - analysis: &crate::proc::analyzer::Analysis, + info: &crate::valid::ModuleInfo, options: &Options, ) -> Result, Error> { let mut words = Vec::new(); let mut w = Writer::new(options)?; - w.write(module, analysis, &mut words)?; + w.write(module, info, &mut words)?; Ok(words) } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index b75db15733..912a840a1d 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -2,10 +2,8 @@ use super::{Instruction, LogicalLayout, Options, PhysicalLayout, WriterFlags}; use crate::{ arena::{Arena, Handle}, - proc::{ - analyzer::{Analysis, FunctionInfo}, - Layouter, ResolveContext, Typifier, TypifyError, - }, + proc::{Layouter, ResolveContext, Typifier, TypifyError}, + valid::{FunctionInfo, ModuleInfo}, }; use spirv::Word; use std::{collections::hash_map::Entry, ops}; @@ -2350,7 +2348,7 @@ impl Writer { fn write_logical_layout( &mut self, ir_module: &crate::Module, - analysis: &Analysis, + mod_info: &ModuleInfo, ) -> Result<(), Error> { Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations); Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450") @@ -2387,13 +2385,13 @@ impl Writer { } for (handle, ir_function) in ir_module.functions.iter() { - let info = &analysis[handle]; + let info = &mod_info[handle]; let id = self.write_function(ir_function, info, ir_module, None)?; self.lookup_function.insert(handle, id); } for (ep_index, ir_ep) in ir_module.entry_points.iter().enumerate() { - let info = analysis.get_entry_point(ep_index); + let info = mod_info.get_entry_point(ep_index); let ep_instruction = self.write_entry_point(ir_ep, info, ir_module)?; ep_instruction.to_words(&mut self.logical_layout.entry_points); } @@ -2426,7 +2424,7 @@ impl Writer { pub fn write( &mut self, ir_module: &crate::Module, - analysis: &Analysis, + info: &ModuleInfo, words: &mut Vec, ) -> Result<(), Error> { self.lookup_function.clear(); @@ -2436,7 +2434,7 @@ impl Writer { self.layouter .initialize(&ir_module.types, &ir_module.constants); - self.write_logical_layout(ir_module, analysis)?; + self.write_logical_layout(ir_module, info)?; self.write_physical_layout(); self.physical_layout.in_words(words); diff --git a/src/lib.rs b/src/lib.rs index 1df4e0004e..4f1992d3d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,6 +40,7 @@ mod arena; pub mod back; pub mod front; pub mod proc; +pub mod valid; pub use crate::arena::{Arena, Handle, Range}; diff --git a/src/proc/mod.rs b/src/proc/mod.rs index c4ad803af3..c1eea6a2aa 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -1,17 +1,14 @@ //! Module processing functionality. -pub mod analyzer; mod layouter; mod namer; mod terminator; mod typifier; -mod validator; pub use layouter::{Alignment, Layouter}; pub use namer::{EntryPointIndex, NameKey, Namer}; pub use terminator::ensure_block_returns; pub use typifier::{ResolveContext, ResolveError, Typifier, TypifyError}; -pub use validator::{TypeFlags, ValidationError, Validator}; impl From for super::ScalarKind { fn from(format: super::StorageFormat) -> Self { diff --git a/src/proc/validator.rs b/src/proc/validator.rs deleted file mode 100644 index 837fed50b4..0000000000 --- a/src/proc/validator.rs +++ /dev/null @@ -1,1829 +0,0 @@ -use super::{ - analyzer::{Analysis, AnalysisError, AnalysisFlags, FunctionInfo, GlobalUse}, - typifier::{ResolveContext, Typifier, TypifyError}, -}; -use crate::{ - arena::{Arena, Handle}, - FastHashSet, -}; -use bit_set::BitSet; -use thiserror::Error; - -const MAX_WORKGROUP_SIZE: u32 = 0x4000; - -bitflags::bitflags! { - #[repr(transparent)] - pub struct TypeFlags: u8 { - /// Can be used for data variables. - const DATA = 0x1; - /// The data type has known size. - const SIZED = 0x2; - /// Can be be used for interfacing between pipeline stages. - const INTERFACE = 0x4; - /// Can be used for host-shareable structures. - const HOST_SHARED = 0x8; - } -} - -bitflags::bitflags! { - #[repr(transparent)] - pub struct BlockFlags: u8 { - /// The control can jump out of this block. - const CAN_JUMP = 0x1; - /// The control is in a loop, can break and continue. - const IN_LOOP = 0x2; - } -} - -struct BlockContext<'a> { - flags: BlockFlags, - expressions: &'a Arena, - types: &'a Arena, - functions: &'a Arena, - return_type: Option>, -} - -impl<'a> BlockContext<'a> { - fn with_flags(&self, flags: BlockFlags) -> Self { - BlockContext { - flags, - expressions: self.expressions, - types: self.types, - functions: self.functions, - return_type: self.return_type, - } - } - - fn get_expression( - &self, - handle: Handle, - ) -> Result<&'a crate::Expression, FunctionError> { - self.expressions - .try_get(handle) - .ok_or(FunctionError::InvalidExpression(handle)) - } -} - -#[derive(Debug)] -pub struct Validator { - analysis_flags: AnalysisFlags, - //Note: this is a bit tricky: some of the front-ends as well as backends - // already have to use the typifier, so the work here is redundant in a way. - typifier: Typifier, - type_flags: Vec, - location_mask: BitSet, - bind_group_masks: Vec, - select_cases: FastHashSet, - valid_expression_list: Vec>, - valid_expression_set: BitSet, -} - -#[derive(Clone, Debug, Error)] -pub enum TypeError { - #[error("The {0:?} scalar width {1} is not supported")] - InvalidWidth(crate::ScalarKind, crate::Bytes), - #[error("The base handle {0:?} can not be resolved")] - UnresolvedBase(Handle), - #[error("Expected data type, found {0:?}")] - InvalidData(Handle), - #[error("Structure type {0:?} can not be a block structure")] - InvalidBlockType(Handle), - #[error("Base type {0:?} for the array is invalid")] - InvalidArrayBaseType(Handle), - #[error("The constant {0:?} can not be used for an array size")] - InvalidArraySizeConstant(Handle), - #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")] - InvalidDynamicArray(String, Handle), -} - -#[derive(Clone, Debug, Error)] -pub enum ConstantError { - #[error("The type doesn't match the constant")] - InvalidType, - #[error("The component handle {0:?} can not be resolved")] - UnresolvedComponent(Handle), - #[error("The array size handle {0:?} can not be resolved")] - UnresolvedSize(Handle), -} - -#[derive(Clone, Debug, Error)] -pub enum GlobalVariableError { - #[error("Usage isn't compatible with the storage class")] - InvalidUsage, - #[error("Type isn't compatible with the storage class")] - InvalidType, - #[error("Storage access {seen:?} exceeds the allowed {allowed:?}")] - InvalidStorageAccess { - allowed: crate::StorageAccess, - seen: crate::StorageAccess, - }, - #[error("Type flags {seen:?} do not meet the required {required:?}")] - MissingTypeFlags { - required: TypeFlags, - seen: TypeFlags, - }, - #[error("Binding decoration is missing or not applicable")] - InvalidBinding, -} - -#[derive(Clone, Debug, Error)] -pub enum LocalVariableError { - #[error("Initializer doesn't match the variable type")] - InitializerType, -} - -#[derive(Clone, Debug, Error)] -pub enum VaryingError { - #[error("The type {0:?} does not match the varying")] - InvalidType(Handle), - #[error("Interpolation is not valid")] - InvalidInterpolation, - #[error("BuiltIn {0:?} is not available at this stage")] - InvalidBuiltInStage(crate::BuiltIn), - #[error("BuiltIn type for {0:?} is invalid")] - InvalidBuiltInType(crate::BuiltIn), - #[error("Struct member {0} is missing a binding")] - MemberMissingBinding(u32), - #[error("Multiple bindings at location {location} are present")] - BindingCollision { location: u32 }, -} - -#[derive(Clone, Debug, Error)] -pub enum ExpressionError { - #[error("Doesn't exist")] - DoesntExist, - #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")] - NotInScope, - #[error("Depends on {0:?}, which has not been processed yet")] - ForwardDependency(Handle), - #[error("Base type {0:?} is not compatible with this expression")] - InvalidBaseType(Handle), - #[error("Accessing with index {0:?} can't be done")] - InvalidIndexType(Handle), - #[error("Accessing index {1} is out of {0:?} bounds")] - IndexOutOfBounds(Handle, u32), - #[error("Function argument {0:?} doesn't exist")] - FunctionArgumentDoesntExist(u32), - #[error("Constant {0:?} doesn't exist")] - ConstantDoesntExist(Handle), - #[error("Global variable {0:?} doesn't exist")] - GlobalVarDoesntExist(Handle), - #[error("Local variable {0:?} doesn't exist")] - LocalVarDoesntExist(Handle), - #[error("Loading of {0:?} can't be done")] - InvalidPointerType(Handle), - #[error("Array length of {0:?} can't be done")] - InvalidArrayType(Handle), - #[error("Compose type {0:?} doesn't exist")] - ComposeTypeDoesntExist(Handle), - #[error("Composing of type {0:?} can't be done")] - InvalidComposeType(Handle), - #[error("Composing expects {expected} components but {given} were given")] - InvalidComposeCount { given: u32, expected: u32 }, - #[error("Composing {0}'s component {1:?} is not expected")] - InvalidComponentType(u32, Handle), - #[error("Operation {0:?} can't work with {1:?}")] - InvalidUnaryOperandType(crate::UnaryOperator, Handle), - #[error("Operation {0:?} can't work with {1:?} and {2:?}")] - InvalidBinaryOperandTypes( - crate::BinaryOperator, - Handle, - Handle, - ), - #[error("Selecting is not possible")] - InvalidSelectTypes, - #[error("Relational argument {0:?} is not a boolean vector")] - InvalidBooleanVector(Handle), - #[error("Relational argument {0:?} is not a float")] - InvalidFloatArgument(Handle), -} - -#[derive(Clone, Debug, Error)] -pub enum CallError { - #[error("Bad function")] - InvalidFunction, - #[error("The callee is declared after the caller")] - ForwardDeclaredFunction, - #[error("Argument {index} expression is invalid")] - Argument { - index: usize, - #[source] - error: ExpressionError, - }, - #[error("Result expression {0:?} has already been introduced earlier")] - ResultAlreadyInScope(Handle), - #[error("Result value is invalid")] - ResultValue(#[source] ExpressionError), - #[error("Requires {required} arguments, but {seen} are provided")] - ArgumentCount { required: usize, seen: usize }, - #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")] - ArgumentType { - index: usize, - required: Handle, - seen_expression: Handle, - }, - #[error("Result value {seen_expression:?} does not match the type {required:?}")] - ResultType { - required: Option>, - seen_expression: Option>, - }, -} - -#[derive(Clone, Debug, Error)] -pub enum FunctionError { - #[error(transparent)] - Resolve(#[from] TypifyError), - #[error("Expression {handle:?} is invalid")] - Expression { - handle: Handle, - #[source] - error: ExpressionError, - }, - #[error("Expression {0:?} can't be introduced - it's already in scope")] - ExpressionAlreadyInScope(Handle), - #[error("Local variable {handle:?} '{name}' is invalid")] - LocalVariable { - handle: Handle, - name: String, - #[source] - error: LocalVariableError, - }, - #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")] - InvalidArgumentType { index: usize, name: String }, - #[error("There are instructions after `return`/`break`/`continue`")] - InstructionsAfterReturn, - #[error("The `break`/`continue` is used outside of a loop context")] - BreakContinueOutsideOfLoop, - #[error("The `return` is called within a `continuing` block")] - InvalidReturnSpot, - #[error("The `return` value {0:?} does not match the function return value")] - InvalidReturnType(Option>), - #[error("The `if` condition {0:?} is not a boolean scalar")] - InvalidIfType(Handle), - #[error("The `switch` value {0:?} is not an integer scalar")] - InvalidSwitchType(Handle), - #[error("Multiple `switch` cases for {0} are present")] - ConflictingSwitchCase(i32), - #[error("The pointer {0:?} doesn't relate to a valid destination for a store")] - InvalidStorePointer(Handle), - #[error("The value {0:?} can not be stored")] - InvalidStoreValue(Handle), - #[error("Store of {value:?} into {pointer:?} doesn't have matching types")] - InvalidStoreTypes { - pointer: Handle, - value: Handle, - }, - #[error("The image array can't be indexed by {0:?}")] - InvalidArrayIndex(Handle), - #[error("The expression {0:?} is currupted")] - InvalidExpression(Handle), - #[error("The expression {0:?} is not an image")] - InvalidImage(Handle), - #[error("Call to {function:?} is invalid")] - InvalidCall { - function: Handle, - #[source] - error: CallError, - }, -} - -#[derive(Clone, Debug, Error)] -pub enum EntryPointError { - #[error("Multiple conflicting entry points")] - Conflict, - #[error("Early depth test is not applicable")] - UnexpectedEarlyDepthTest, - #[error("Workgroup size is not applicable")] - UnexpectedWorkgroupSize, - #[error("Workgroup size is out of range")] - OutOfRangeWorkgroupSize, - #[error("Global variable {0:?} is used incorrectly as {1:?}")] - InvalidGlobalUsage(Handle, GlobalUse), - #[error("Bindings for {0:?} conflict with other resource")] - BindingCollision(Handle), - #[error("Argument {0} varying error")] - Argument(u32, #[source] VaryingError), - #[error("Result varying error")] - Result(#[source] VaryingError), - #[error("Location {location} onterpolation of an integer has to be flat")] - InvalidIntegerInterpolation { location: u32 }, - #[error(transparent)] - Function(#[from] FunctionError), -} - -#[derive(Clone, Debug, Error)] -pub enum ValidationError { - #[error("Type {handle:?} '{name}' is invalid")] - Type { - handle: Handle, - name: String, - #[source] - error: TypeError, - }, - #[error("Constant {handle:?} '{name}' is invalid")] - Constant { - handle: Handle, - name: String, - #[source] - error: ConstantError, - }, - #[error("Global variable {handle:?} '{name}' is invalid")] - GlobalVariable { - handle: Handle, - name: String, - #[source] - error: GlobalVariableError, - }, - #[error("Function {handle:?} '{name}' is invalid")] - Function { - handle: Handle, - name: String, - #[source] - error: FunctionError, - }, - #[error("Entry point {name} at {stage:?} is invalid")] - EntryPoint { - stage: crate::ShaderStage, - name: String, - #[source] - error: EntryPointError, - }, - #[error(transparent)] - Analysis(#[from] AnalysisError), - #[error("Module is corrupted")] - Corrupted, -} - -fn storage_usage(access: crate::StorageAccess) -> GlobalUse { - let mut storage_usage = GlobalUse::QUERY; - if access.contains(crate::StorageAccess::LOAD) { - storage_usage |= GlobalUse::READ; - } - if access.contains(crate::StorageAccess::STORE) { - storage_usage |= GlobalUse::WRITE; - } - storage_usage -} - -impl crate::TypeInner { - fn is_sized(&self) -> bool { - match *self { - Self::Scalar { .. } - | Self::Vector { .. } - | Self::Matrix { .. } - | Self::Array { - size: crate::ArraySize::Constant(_), - .. - } - | Self::Pointer { .. } - | Self::ValuePointer { .. } - | Self::Struct { .. } => true, - Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } => false, - } - } -} - -struct VaryingContext<'a> { - ty: Handle, - stage: crate::ShaderStage, - output: bool, - types: &'a Arena, - location_mask: &'a mut BitSet, -} - -impl VaryingContext<'_> { - fn validate_impl(&mut self, binding: &crate::Binding) -> Result<(), VaryingError> { - use crate::{ - BuiltIn as Bi, ScalarKind as Sk, ShaderStage as St, TypeInner as Ti, VectorSize as Vs, - }; - - let ty_inner = &self.types[self.ty].inner; - match *binding { - crate::Binding::BuiltIn(built_in) => { - let width = 4; - let (visible, type_good) = match built_in { - Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => ( - self.stage == St::Vertex && !self.output, - *ty_inner - == Ti::Scalar { - kind: Sk::Uint, - width, - }, - ), - Bi::ClipDistance => ( - self.stage == St::Vertex && self.output, - match *ty_inner { - Ti::Array { base, .. } => { - self.types[base].inner - == Ti::Scalar { - kind: Sk::Float, - width, - } - } - _ => false, - }, - ), - Bi::PointSize => ( - self.stage == St::Vertex && self.output, - *ty_inner - == Ti::Scalar { - kind: Sk::Float, - width, - }, - ), - Bi::Position => ( - match self.stage { - St::Vertex => self.output, - St::Fragment => !self.output, - St::Compute => false, - }, - *ty_inner - == Ti::Vector { - size: Vs::Quad, - kind: Sk::Float, - width, - }, - ), - Bi::FragDepth => ( - self.stage == St::Fragment && self.output, - *ty_inner - == Ti::Scalar { - kind: Sk::Float, - width, - }, - ), - Bi::FrontFacing => ( - self.stage == St::Fragment && !self.output, - *ty_inner - == Ti::Scalar { - kind: Sk::Bool, - width: crate::BOOL_WIDTH, - }, - ), - Bi::SampleIndex => ( - self.stage == St::Fragment && !self.output, - *ty_inner - == Ti::Scalar { - kind: Sk::Uint, - width, - }, - ), - Bi::SampleMask => ( - self.stage == St::Fragment, - *ty_inner - == Ti::Scalar { - kind: Sk::Uint, - width, - }, - ), - Bi::LocalInvocationIndex => ( - self.stage == St::Compute && !self.output, - *ty_inner - == Ti::Scalar { - kind: Sk::Uint, - width, - }, - ), - Bi::GlobalInvocationId - | Bi::LocalInvocationId - | Bi::WorkGroupId - | Bi::WorkGroupSize => ( - self.stage == St::Compute && !self.output, - *ty_inner - == Ti::Vector { - size: Vs::Tri, - kind: Sk::Uint, - width, - }, - ), - }; - - if !visible { - return Err(VaryingError::InvalidBuiltInStage(built_in)); - } - if !type_good { - log::warn!("Wrong builtin type: {:?}", ty_inner); - return Err(VaryingError::InvalidBuiltInType(built_in)); - } - } - crate::Binding::Location(location, interpolation) => { - if !self.location_mask.insert(location as usize) { - return Err(VaryingError::BindingCollision { location }); - } - let needs_interpolation = - self.stage == crate::ShaderStage::Fragment && !self.output; - if !needs_interpolation && interpolation.is_some() { - return Err(VaryingError::InvalidInterpolation); - } - match ty_inner.scalar_kind() { - Some(crate::ScalarKind::Float) => {} - Some(_) - if needs_interpolation - && interpolation != Some(crate::Interpolation::Flat) => - { - return Err(VaryingError::InvalidInterpolation); - } - Some(_) => {} - None => return Err(VaryingError::InvalidType(self.ty)), - } - } - } - - Ok(()) - } - - fn validate(mut self, binding: Option<&crate::Binding>) -> Result<(), VaryingError> { - match binding { - Some(binding) => self.validate_impl(binding), - None => { - match self.types[self.ty].inner { - //TODO: check the member types - crate::TypeInner::Struct { - block: false, - ref members, - } => { - for (index, member) in members.iter().enumerate() { - self.ty = member.ty; - match member.binding { - None => { - return Err(VaryingError::MemberMissingBinding(index as u32)) - } - Some(ref binding) => self.validate_impl(binding)?, - } - } - } - _ => return Err(VaryingError::InvalidType(self.ty)), - } - Ok(()) - } - } - } -} - -struct ExpressionTypeResolver<'a> { - root: Handle, - types: &'a Arena, - typifier: &'a Typifier, -} - -impl<'a> ExpressionTypeResolver<'a> { - fn resolve( - &self, - handle: Handle, - ) -> Result<&'a crate::TypeInner, ExpressionError> { - if handle < self.root { - Ok(self.typifier.get(handle, self.types)) - } else { - Err(ExpressionError::ForwardDependency(handle)) - } - } -} - -impl Validator { - /// Construct a new validator instance. - pub fn new(analysis_flags: AnalysisFlags) -> Self { - Validator { - analysis_flags, - typifier: Typifier::new(), - type_flags: Vec::new(), - location_mask: BitSet::new(), - bind_group_masks: Vec::new(), - select_cases: FastHashSet::default(), - valid_expression_list: Vec::new(), - valid_expression_set: BitSet::new(), - } - } - - fn check_width(kind: crate::ScalarKind, width: crate::Bytes) -> bool { - match kind { - crate::ScalarKind::Bool => width == crate::BOOL_WIDTH, - _ => width == 4, - } - } - - fn validate_type( - &self, - ty: &crate::Type, - handle: Handle, - constants: &Arena, - ) -> Result { - use crate::TypeInner as Ti; - Ok(match ty.inner { - Ti::Scalar { kind, width } | Ti::Vector { kind, width, .. } => { - if !Self::check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } - TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::INTERFACE | TypeFlags::HOST_SHARED - } - Ti::Matrix { width, .. } => { - if !Self::check_width(crate::ScalarKind::Float, width) { - return Err(TypeError::InvalidWidth(crate::ScalarKind::Float, width)); - } - TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::INTERFACE | TypeFlags::HOST_SHARED - } - Ti::Pointer { base, class: _ } => { - if base >= handle { - return Err(TypeError::UnresolvedBase(base)); - } - TypeFlags::DATA | TypeFlags::SIZED - } - Ti::ValuePointer { - size: _, - kind, - width, - class: _, - } => { - if !Self::check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } - TypeFlags::SIZED //TODO: `DATA`? - } - Ti::Array { - base, - size, - stride: _, - } => { - if base >= handle { - return Err(TypeError::UnresolvedBase(base)); - } - let base_flags = self.type_flags[base.index()]; - if !base_flags.contains(TypeFlags::DATA | TypeFlags::SIZED) { - return Err(TypeError::InvalidArrayBaseType(base)); - } - - let sized_flag = match size { - crate::ArraySize::Constant(const_handle) => { - match constants.try_get(const_handle) { - Some(&crate::Constant { - inner: - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Uint(_), - }, - .. - }) => {} - // Accept a signed integer size to avoid - // requiring an explicit uint - // literal. Type inference should make - // this unnecessary. - Some(&crate::Constant { - inner: - crate::ConstantInner::Scalar { - width: _, - value: crate::ScalarValue::Sint(_), - }, - .. - }) => {} - other => { - log::warn!("Array size {:?}", other); - return Err(TypeError::InvalidArraySizeConstant(const_handle)); - } - } - TypeFlags::SIZED - } - crate::ArraySize::Dynamic => TypeFlags::empty(), - }; - let base_mask = TypeFlags::HOST_SHARED | TypeFlags::INTERFACE; - TypeFlags::DATA | (base_flags & base_mask) | sized_flag - } - Ti::Struct { block, ref members } => { - let mut flags = TypeFlags::all(); - for (i, member) in members.iter().enumerate() { - if member.ty >= handle { - return Err(TypeError::UnresolvedBase(member.ty)); - } - let base_flags = self.type_flags[member.ty.index()]; - flags &= base_flags; - if !base_flags.contains(TypeFlags::DATA) { - return Err(TypeError::InvalidData(member.ty)); - } - if block && !base_flags.contains(TypeFlags::INTERFACE) { - return Err(TypeError::InvalidBlockType(member.ty)); - } - // only the last field can be unsized - if i + 1 != members.len() && !base_flags.contains(TypeFlags::SIZED) { - let name = member.name.clone().unwrap_or_default(); - return Err(TypeError::InvalidDynamicArray(name, member.ty)); - } - } - //TODO: check the spans - flags - } - Ti::Image { .. } | Ti::Sampler { .. } => TypeFlags::empty(), - }) - } - - fn validate_constant( - &self, - handle: Handle, - constants: &Arena, - types: &Arena, - ) -> Result<(), ConstantError> { - let con = &constants[handle]; - match con.inner { - crate::ConstantInner::Scalar { width, ref value } => { - if !Self::check_width(value.scalar_kind(), width) { - return Err(ConstantError::InvalidType); - } - } - crate::ConstantInner::Composite { ty, ref components } => { - match types[ty].inner { - crate::TypeInner::Array { - size: crate::ArraySize::Dynamic, - .. - } => { - return Err(ConstantError::InvalidType); - } - crate::TypeInner::Array { - size: crate::ArraySize::Constant(size_handle), - .. - } => { - if handle <= size_handle { - return Err(ConstantError::UnresolvedSize(size_handle)); - } - } - _ => {} //TODO - } - if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) { - return Err(ConstantError::UnresolvedComponent(comp)); - } - } - } - Ok(()) - } - - fn validate_global_var( - &self, - var: &crate::GlobalVariable, - types: &Arena, - ) -> Result<(), GlobalVariableError> { - log::debug!("var {:?}", var); - let (allowed_storage_access, required_type_flags, is_resource) = match var.class { - crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage), - crate::StorageClass::Storage => { - match types[var.ty].inner { - crate::TypeInner::Struct { .. } => (), - _ => return Err(GlobalVariableError::InvalidType), - } - ( - crate::StorageAccess::all(), - TypeFlags::DATA | TypeFlags::HOST_SHARED, - true, - ) - } - crate::StorageClass::Uniform => { - match types[var.ty].inner { - crate::TypeInner::Struct { .. } => (), - _ => return Err(GlobalVariableError::InvalidType), - } - ( - crate::StorageAccess::empty(), - TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHARED, - true, - ) - } - crate::StorageClass::Handle => { - let access = match types[var.ty].inner { - crate::TypeInner::Image { - class: crate::ImageClass::Storage(_), - .. - } => crate::StorageAccess::all(), - crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => { - crate::StorageAccess::empty() - } - _ => return Err(GlobalVariableError::InvalidType), - }; - (access, TypeFlags::empty(), true) - } - crate::StorageClass::Private | crate::StorageClass::WorkGroup => { - (crate::StorageAccess::empty(), TypeFlags::DATA, false) - } - crate::StorageClass::PushConstant => ( - crate::StorageAccess::LOAD, - TypeFlags::DATA | TypeFlags::HOST_SHARED, - false, - ), - }; - - if !allowed_storage_access.contains(var.storage_access) { - return Err(GlobalVariableError::InvalidStorageAccess { - seen: var.storage_access, - allowed: allowed_storage_access, - }); - } - - let type_flags = self.type_flags[var.ty.index()]; - if !type_flags.contains(required_type_flags) { - return Err(GlobalVariableError::MissingTypeFlags { - seen: type_flags, - required: required_type_flags, - }); - } - - if is_resource != var.binding.is_some() { - return Err(GlobalVariableError::InvalidBinding); - } - - Ok(()) - } - - fn validate_local_var( - &self, - var: &crate::LocalVariable, - types: &Arena, - constants: &Arena, - ) -> Result<(), LocalVariableError> { - log::debug!("var {:?}", var); - if let Some(const_handle) = var.init { - match constants[const_handle].inner { - crate::ConstantInner::Scalar { width, ref value } => { - let ty_inner = crate::TypeInner::Scalar { - width, - kind: value.scalar_kind(), - }; - if types[var.ty].inner != ty_inner { - return Err(LocalVariableError::InitializerType); - } - } - crate::ConstantInner::Composite { ty, components: _ } => { - if ty != var.ty { - return Err(LocalVariableError::InitializerType); - } - } - } - } - Ok(()) - } - - fn validate_call( - &mut self, - function: Handle, - arguments: &[Handle], - result: Option>, - context: &BlockContext, - ) -> Result<(), CallError> { - let fun = context - .functions - .try_get(function) - .ok_or(CallError::InvalidFunction)?; - if fun.arguments.len() != arguments.len() { - return Err(CallError::ArgumentCount { - required: fun.arguments.len(), - seen: arguments.len(), - }); - } - for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() { - let ty = self - .resolve_statement_type_impl(expr, context.types) - .map_err(|error| CallError::Argument { index, error })?; - if ty != &context.types[arg.ty].inner { - return Err(CallError::ArgumentType { - index, - required: arg.ty, - seen_expression: expr, - }); - } - } - - if let Some(expr) = result { - if self.valid_expression_set.insert(expr.index()) { - self.valid_expression_list.push(expr); - } else { - return Err(CallError::ResultAlreadyInScope(expr)); - } - } - - let result_ty = result - .map(|expr| self.resolve_statement_type_impl(expr, context.types)) - .transpose() - .map_err(CallError::ResultValue)?; - let expected_ty = fun.result.as_ref().map(|fr| &context.types[fr.ty].inner); - if result_ty != expected_ty { - log::error!( - "Called function returns {:?} where {:?} is expected", - result_ty, - expected_ty - ); - return Err(CallError::ResultType { - required: fun.result.as_ref().map(|fr| fr.ty), - seen_expression: result, - }); - } - Ok(()) - } - - #[allow(unused)] - fn validate_expression( - &self, - root: Handle, - expression: &crate::Expression, - function: &crate::Function, - module: &crate::Module, - ) -> Result<(), ExpressionError> { - use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti}; - - let resolver = ExpressionTypeResolver { - root, - types: &module.types, - typifier: &self.typifier, - }; - - match *expression { - E::Access { base, index } => { - match *resolver.resolve(base)? { - Ti::Vector { .. } - | Ti::Matrix { .. } - | Ti::Array { .. } - | Ti::Pointer { .. } => {} - ref other => { - log::error!("Indexing of {:?}", other); - return Err(ExpressionError::InvalidBaseType(base)); - } - } - match *resolver.resolve(index)? { - //TODO: only allow one of these - Ti::Scalar { - kind: Sk::Sint, - width: _, - } - | Ti::Scalar { - kind: Sk::Uint, - width: _, - } => {} - ref other => { - log::error!("Indexing by {:?}", other); - return Err(ExpressionError::InvalidIndexType(index)); - } - } - } - E::AccessIndex { base, index } => { - let limit = match *resolver.resolve(base)? { - Ti::Vector { size, .. } => size as u32, - Ti::Matrix { columns, .. } => columns as u32, - Ti::Array { - size: crate::ArraySize::Constant(handle), - .. - } => module.constants[handle].to_array_length().unwrap(), - Ti::Array { .. } => !0, // can't statically know, but need run-time checks - Ti::Pointer { .. } => !0, //TODO - Ti::Struct { - ref members, - block: _, - } => members.len() as u32, - ref other => { - log::error!("Indexing of {:?}", other); - return Err(ExpressionError::InvalidBaseType(base)); - } - }; - if index >= limit { - return Err(ExpressionError::IndexOutOfBounds(base, index)); - } - } - E::Constant(handle) => { - let _ = module - .constants - .try_get(handle) - .ok_or(ExpressionError::ConstantDoesntExist(handle))?; - } - E::Compose { ref components, ty } => { - match module - .types - .try_get(ty) - .ok_or(ExpressionError::ComposeTypeDoesntExist(ty))? - .inner - { - // vectors are composed from scalars or other vectors - Ti::Vector { size, kind, width } => { - let inner = Ti::Scalar { kind, width }; - let mut total = 0; - for (index, &comp) in components.iter().enumerate() { - total += match *resolver.resolve(comp)? { - Ti::Scalar { - kind: comp_kind, - width: comp_width, - } if comp_kind == kind && comp_width == width => 1, - Ti::Vector { - size: comp_size, - kind: comp_kind, - width: comp_width, - } if comp_kind == kind && comp_width == width => comp_size as u32, - ref other => { - log::error!("Vector component[{}] type {:?}", index, other); - return Err(ExpressionError::InvalidComponentType( - index as u32, - comp, - )); - } - }; - } - if size as u32 != total { - return Err(ExpressionError::InvalidComposeCount { - expected: size as u32, - given: total, - }); - } - } - // matrix are composed from column vectors - Ti::Matrix { - columns, - rows, - width, - } => { - let inner = Ti::Vector { - size: rows, - kind: Sk::Float, - width, - }; - if columns as usize != components.len() { - return Err(ExpressionError::InvalidComposeCount { - expected: columns as u32, - given: components.len() as u32, - }); - } - for (index, &comp) in components.iter().enumerate() { - let tin = resolver.resolve(comp)?; - if tin != &inner { - log::error!("Matrix component[{}] type {:?}", index, tin); - return Err(ExpressionError::InvalidComponentType( - index as u32, - comp, - )); - } - } - } - Ti::Array { - base, - size: crate::ArraySize::Constant(handle), - stride: _, - } => { - let count = module.constants[handle].to_array_length().unwrap(); - if count as usize != components.len() { - return Err(ExpressionError::InvalidComposeCount { - expected: count, - given: components.len() as u32, - }); - } - let base_inner = &module.types[base].inner; - for (index, &comp) in components.iter().enumerate() { - let tin = resolver.resolve(comp)?; - if tin != base_inner { - log::error!("Array component[{}] type {:?}", index, tin); - return Err(ExpressionError::InvalidComponentType( - index as u32, - comp, - )); - } - } - } - Ti::Struct { - block: _, - ref members, - } => { - for (index, (member, &comp)) in members.iter().zip(components).enumerate() { - let tin = resolver.resolve(comp)?; - if tin != &module.types[member.ty].inner { - log::error!("Struct component[{}] type {:?}", index, tin); - return Err(ExpressionError::InvalidComponentType( - index as u32, - comp, - )); - } - } - if members.len() != components.len() { - return Err(ExpressionError::InvalidComposeCount { - given: components.len() as u32, - expected: members.len() as u32, - }); - } - } - ref other => { - log::error!("Composing of {:?}", other); - return Err(ExpressionError::InvalidComposeType(ty)); - } - } - } - E::FunctionArgument(index) => { - if index >= function.arguments.len() as u32 { - return Err(ExpressionError::FunctionArgumentDoesntExist(index)); - } - } - E::GlobalVariable(handle) => { - let _ = module - .global_variables - .try_get(handle) - .ok_or(ExpressionError::GlobalVarDoesntExist(handle))?; - } - E::LocalVariable(handle) => { - let _ = function - .local_variables - .try_get(handle) - .ok_or(ExpressionError::LocalVarDoesntExist(handle))?; - } - E::Load { pointer } => match *resolver.resolve(pointer)? { - Ti::Pointer { .. } | Ti::ValuePointer { .. } => {} - ref other => { - log::error!("Loading {:?}", other); - return Err(ExpressionError::InvalidPointerType(pointer)); - } - }, - E::ImageSample { - image, - sampler, - coordinate, - array_index, - offset, - level, - depth_ref, - } => {} - E::ImageLoad { - image, - coordinate, - array_index, - index, - } => {} - E::ImageQuery { image, query } => {} - E::Unary { op, expr } => { - use crate::UnaryOperator as Uo; - let inner = resolver.resolve(expr)?; - match (op, inner.scalar_kind()) { - (_, Some(Sk::Sint)) - | (_, Some(Sk::Bool)) - | (Uo::Negate, Some(Sk::Float)) - | (Uo::Not, Some(Sk::Uint)) => {} - other => { - log::error!("Op {:?} kind {:?}", op, other); - return Err(ExpressionError::InvalidUnaryOperandType(op, expr)); - } - } - } - E::Binary { op, left, right } => { - use crate::BinaryOperator as Bo; - let left_inner = resolver.resolve(left)?; - let right_inner = resolver.resolve(right)?; - let good = match op { - Bo::Add | Bo::Subtract | Bo::Divide | Bo::Modulo => match *left_inner { - Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { - Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, - Sk::Bool => false, - }, - _ => false, - }, - Bo::Multiply => { - let kind_match = match left_inner.scalar_kind() { - Some(Sk::Uint) | Some(Sk::Sint) | Some(Sk::Float) => true, - Some(Sk::Bool) | None => false, - }; - //TODO: should we be more restrictive here? I.e. expect scalar only to the left. - let types_match = match (left_inner, right_inner) { - (&Ti::Scalar { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. }) - | (&Ti::Vector { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. }) - | (&Ti::Scalar { kind: kind1, .. }, &Ti::Vector { kind: kind2, .. }) => { - kind1 == kind2 - } - ( - &Ti::Scalar { - kind: Sk::Float, .. - }, - &Ti::Matrix { .. }, - ) - | ( - &Ti::Matrix { .. }, - &Ti::Scalar { - kind: Sk::Float, .. - }, - ) => true, - ( - &Ti::Vector { - kind: kind1, - size: size1, - .. - }, - &Ti::Vector { - kind: kind2, - size: size2, - .. - }, - ) => kind1 == kind2 && size1 == size2, - ( - &Ti::Matrix { columns, .. }, - &Ti::Vector { - kind: Sk::Float, - size, - .. - }, - ) => columns == size, - ( - &Ti::Vector { - kind: Sk::Float, - size, - .. - }, - &Ti::Matrix { rows, .. }, - ) => size == rows, - (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => { - columns == rows - } - _ => false, - }; - let left_width = match *left_inner { - Ti::Scalar { width, .. } - | Ti::Vector { width, .. } - | Ti::Matrix { width, .. } => width, - _ => 0, - }; - let right_width = match *right_inner { - Ti::Scalar { width, .. } - | Ti::Vector { width, .. } - | Ti::Matrix { width, .. } => width, - _ => 0, - }; - kind_match && types_match && left_width == right_width - } - Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner, - Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { - match *left_inner { - Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { - Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, - Sk::Bool => false, - }, - ref other => { - log::error!("Op {:?} left type {:?}", op, other); - false - } - } - } - Bo::LogicalAnd | Bo::LogicalOr => match *left_inner { - Ti::Scalar { kind: Sk::Bool, .. } | Ti::Vector { kind: Sk::Bool, .. } => { - left_inner == right_inner - } - ref other => { - log::error!("Op {:?} left type {:?}", op, other); - false - } - }, - Bo::And | Bo::ExclusiveOr | Bo::InclusiveOr => match *left_inner { - Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { - Sk::Sint | Sk::Uint => left_inner == right_inner, - Sk::Bool | Sk::Float => false, - }, - ref other => { - log::error!("Op {:?} left type {:?}", op, other); - false - } - }, - Bo::ShiftLeft | Bo::ShiftRight => { - let (base_size, base_kind) = match *left_inner { - Ti::Scalar { kind, .. } => (Ok(None), kind), - Ti::Vector { size, kind, .. } => (Ok(Some(size)), kind), - ref other => { - log::error!("Op {:?} base type {:?}", op, other); - (Err(()), Sk::Bool) - } - }; - let shift_size = match *right_inner { - Ti::Scalar { kind: Sk::Uint, .. } => Ok(None), - Ti::Vector { - size, - kind: Sk::Uint, - .. - } => Ok(Some(size)), - ref other => { - log::error!("Op {:?} shift type {:?}", op, other); - Err(()) - } - }; - match base_kind { - Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, - Sk::Float | Sk::Bool => false, - } - } - }; - if !good { - return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right)); - } - } - E::Select { - condition, - accept, - reject, - } => { - let accept_inner = resolver.resolve(accept)?; - let reject_inner = resolver.resolve(reject)?; - let condition_good = match *resolver.resolve(condition)? { - Ti::Scalar { - kind: Sk::Bool, - width: _, - } => accept_inner.is_sized(), - Ti::Vector { - size, - kind: Sk::Bool, - width: _, - } => match *accept_inner { - Ti::Vector { - size: other_size, .. - } => size == other_size, - _ => false, - }, - _ => false, - }; - if !condition_good || accept_inner != reject_inner { - return Err(ExpressionError::InvalidSelectTypes); - } - } - E::Derivative { axis, expr } => { - //TODO: check stage - } - E::Relational { fun, argument } => { - use crate::RelationalFunction as Rf; - let argument_inner = resolver.resolve(argument)?; - match fun { - Rf::All | Rf::Any => match *argument_inner { - Ti::Vector { kind: Sk::Bool, .. } => {} - ref other => { - log::error!("All/Any of type {:?}", other); - return Err(ExpressionError::InvalidBooleanVector(argument)); - } - }, - Rf::IsNan | Rf::IsInf | Rf::IsFinite | Rf::IsNormal => match *argument_inner { - Ti::Scalar { - kind: Sk::Float, .. - } - | Ti::Vector { - kind: Sk::Float, .. - } => {} - ref other => { - log::error!("Float test of type {:?}", other); - return Err(ExpressionError::InvalidFloatArgument(argument)); - } - }, - } - } - E::Math { - fun, - arg, - arg1, - arg2, - } => {} - E::As { - expr, - kind, - convert, - } => {} - E::Call(function) => {} - E::ArrayLength(expr) => match *resolver.resolve(expr)? { - Ti::Array { .. } => {} - ref other => { - log::error!("Array length of {:?}", other); - return Err(ExpressionError::InvalidArrayType(expr)); - } - }, - } - Ok(()) - } - - fn resolve_statement_type_impl<'a>( - &'a self, - handle: Handle, - types: &'a Arena, - ) -> Result<&'a crate::TypeInner, ExpressionError> { - if !self.valid_expression_set.contains(handle.index()) { - return Err(ExpressionError::NotInScope); - } - self.typifier - .try_get(handle, types) - .ok_or(ExpressionError::DoesntExist) - } - - fn resolve_statement_type<'a>( - &'a self, - handle: Handle, - types: &'a Arena, - ) -> Result<&'a crate::TypeInner, FunctionError> { - self.resolve_statement_type_impl(handle, types) - .map_err(|error| FunctionError::Expression { handle, error }) - } - - fn validate_block_impl( - &mut self, - statements: &[crate::Statement], - context: &BlockContext, - ) -> Result<(), FunctionError> { - use crate::{Statement as S, TypeInner as Ti}; - let mut finished = false; - for statement in statements { - if finished { - return Err(FunctionError::InstructionsAfterReturn); - } - match *statement { - S::Emit(ref range) => { - for handle in range.clone() { - if self.valid_expression_set.insert(handle.index()) { - self.valid_expression_list.push(handle); - } else { - return Err(FunctionError::ExpressionAlreadyInScope(handle)); - } - } - } - S::Block(ref block) => self.validate_block(block, context)?, - S::If { - condition, - ref accept, - ref reject, - } => { - match *self.resolve_statement_type(condition, context.types)? { - Ti::Scalar { - kind: crate::ScalarKind::Bool, - width: _, - } => {} - _ => return Err(FunctionError::InvalidIfType(condition)), - } - self.validate_block(accept, context)?; - self.validate_block(reject, context)?; - } - S::Switch { - selector, - ref cases, - ref default, - } => { - match *self.resolve_statement_type(selector, context.types)? { - Ti::Scalar { - kind: crate::ScalarKind::Sint, - width: _, - } => {} - _ => return Err(FunctionError::InvalidSwitchType(selector)), - } - self.select_cases.clear(); - for case in cases { - if !self.select_cases.insert(case.value) { - return Err(FunctionError::ConflictingSwitchCase(case.value)); - } - } - for case in cases { - self.validate_block(&case.body, context)?; - } - self.validate_block(default, context)?; - } - S::Loop { - ref body, - ref continuing, - } => { - // special handling for block scoping is needed here, - // because the continuing{} block inherits the scope - let base_expression_count = self.valid_expression_list.len(); - self.validate_block_impl( - body, - &context.with_flags(BlockFlags::CAN_JUMP | BlockFlags::IN_LOOP), - )?; - self.validate_block_impl(continuing, &context.with_flags(BlockFlags::empty()))?; - for handle in self.valid_expression_list.drain(base_expression_count..) { - self.valid_expression_set.remove(handle.index()); - } - } - S::Break | S::Continue => { - if !context.flags.contains(BlockFlags::IN_LOOP) { - return Err(FunctionError::BreakContinueOutsideOfLoop); - } - finished = true; - } - S::Return { value } => { - if !context.flags.contains(BlockFlags::CAN_JUMP) { - return Err(FunctionError::InvalidReturnSpot); - } - let value_ty = value - .map(|expr| self.resolve_statement_type(expr, context.types)) - .transpose()?; - let expected_ty = context.return_type.map(|ty| &context.types[ty].inner); - if value_ty != expected_ty { - log::error!( - "Returning {:?} where {:?} is expected", - value_ty, - expected_ty - ); - return Err(FunctionError::InvalidReturnType(value)); - } - finished = true; - } - S::Kill => { - finished = true; - } - S::Store { pointer, value } => { - let mut current = pointer; - loop { - self.typifier.try_get(current, context.types).ok_or( - FunctionError::Expression { - handle: current, - error: ExpressionError::DoesntExist, - }, - )?; - match context.expressions[current] { - crate::Expression::Access { base, .. } - | crate::Expression::AccessIndex { base, .. } => current = base, - crate::Expression::LocalVariable(_) - | crate::Expression::GlobalVariable(_) - | crate::Expression::FunctionArgument(_) => break, - _ => return Err(FunctionError::InvalidStorePointer(current)), - } - } - - let value_ty = self.resolve_statement_type(value, context.types)?; - match *value_ty { - Ti::Image { .. } | Ti::Sampler { .. } => { - return Err(FunctionError::InvalidStoreValue(value)); - } - _ => {} - } - let good = match self.typifier.try_get(pointer, context.types) { - Some(&Ti::Pointer { base, class: _ }) => { - *value_ty == context.types[base].inner - } - Some(&Ti::ValuePointer { - size: Some(size), - kind, - width, - class: _, - }) => *value_ty == Ti::Vector { size, kind, width }, - Some(&Ti::ValuePointer { - size: None, - kind, - width, - class: _, - }) => *value_ty == Ti::Scalar { kind, width }, - _ => false, - }; - if !good { - return Err(FunctionError::InvalidStoreTypes { pointer, value }); - } - } - S::ImageStore { - image, - coordinate: _, - array_index, - value, - } => { - let _expected_coordinate_ty = match *context.get_expression(image)? { - crate::Expression::GlobalVariable(_var_handle) => (), //TODO - _ => return Err(FunctionError::InvalidImage(image)), - }; - let value_ty = self.typifier.get(value, context.types); - match *value_ty { - Ti::Scalar { .. } | Ti::Vector { .. } => {} - _ => { - return Err(FunctionError::InvalidStoreValue(value)); - } - } - if let Some(expr) = array_index { - match *self.typifier.get(expr, context.types) { - Ti::Scalar { - kind: crate::ScalarKind::Sint, - width: _, - } => (), - _ => return Err(FunctionError::InvalidArrayIndex(expr)), - } - } - } - S::Call { - function, - ref arguments, - result, - } => { - if let Err(error) = self.validate_call(function, arguments, result, context) { - return Err(FunctionError::InvalidCall { function, error }); - } - } - } - } - Ok(()) - } - - fn validate_block( - &mut self, - statements: &[crate::Statement], - context: &BlockContext, - ) -> Result<(), FunctionError> { - let base_expression_count = self.valid_expression_list.len(); - self.validate_block_impl(statements, context)?; - for handle in self.valid_expression_list.drain(base_expression_count..) { - self.valid_expression_set.remove(handle.index()); - } - Ok(()) - } - - fn validate_function( - &mut self, - fun: &crate::Function, - _info: &FunctionInfo, - module: &crate::Module, - ) -> Result<(), FunctionError> { - let resolve_ctx = ResolveContext { - constants: &module.constants, - global_vars: &module.global_variables, - local_vars: &fun.local_variables, - functions: &module.functions, - arguments: &fun.arguments, - }; - self.typifier - .resolve_all(&fun.expressions, &module.types, &resolve_ctx)?; - - for (var_handle, var) in fun.local_variables.iter() { - self.validate_local_var(var, &module.types, &module.constants) - .map_err(|error| FunctionError::LocalVariable { - handle: var_handle, - name: var.name.clone().unwrap_or_default(), - error, - })?; - } - - for (index, argument) in fun.arguments.iter().enumerate() { - if !self.type_flags[argument.ty.index()].contains(TypeFlags::DATA) { - return Err(FunctionError::InvalidArgumentType { - index, - name: argument.name.clone().unwrap_or_default(), - }); - } - } - - self.valid_expression_set.clear(); - for (handle, expr) in fun.expressions.iter() { - if expr.needs_pre_emit() { - self.valid_expression_set.insert(handle.index()); - } - if let Err(error) = self.validate_expression(handle, expr, fun, module) { - return Err(FunctionError::Expression { handle, error }); - } - } - - self.validate_block( - &fun.body, - &BlockContext { - flags: BlockFlags::CAN_JUMP, - expressions: &fun.expressions, - types: &module.types, - functions: &module.functions, - return_type: fun.result.as_ref().map(|fr| fr.ty), - }, - ) - } - - fn validate_entry_point( - &mut self, - ep: &crate::EntryPoint, - info: &FunctionInfo, - module: &crate::Module, - ) -> Result<(), EntryPointError> { - if ep.early_depth_test.is_some() && ep.stage != crate::ShaderStage::Fragment { - return Err(EntryPointError::UnexpectedEarlyDepthTest); - } - if ep.stage == crate::ShaderStage::Compute { - if ep - .workgroup_size - .iter() - .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE) - { - return Err(EntryPointError::OutOfRangeWorkgroupSize); - } - } else if ep.workgroup_size != [0; 3] { - return Err(EntryPointError::UnexpectedWorkgroupSize); - } - - self.location_mask.clear(); - for (index, fa) in ep.function.arguments.iter().enumerate() { - let ctx = VaryingContext { - ty: fa.ty, - stage: ep.stage, - output: false, - types: &module.types, - location_mask: &mut self.location_mask, - }; - ctx.validate(fa.binding.as_ref()) - .map_err(|e| EntryPointError::Argument(index as u32, e))?; - } - - self.location_mask.clear(); - if let Some(ref fr) = ep.function.result { - let ctx = VaryingContext { - ty: fr.ty, - stage: ep.stage, - output: true, - types: &module.types, - location_mask: &mut self.location_mask, - }; - ctx.validate(fr.binding.as_ref()) - .map_err(EntryPointError::Result)?; - } - - for bg in self.bind_group_masks.iter_mut() { - bg.clear(); - } - for (var_handle, var) in module.global_variables.iter() { - let usage = info[var_handle]; - if usage.is_empty() { - continue; - } - - let allowed_usage = match var.class { - crate::StorageClass::Function => unreachable!(), - crate::StorageClass::Uniform => GlobalUse::READ | GlobalUse::QUERY, - crate::StorageClass::Storage => storage_usage(var.storage_access), - crate::StorageClass::Handle => match module.types[var.ty].inner { - crate::TypeInner::Image { - class: crate::ImageClass::Storage(_), - .. - } => storage_usage(var.storage_access), - _ => GlobalUse::READ | GlobalUse::QUERY, - }, - crate::StorageClass::Private | crate::StorageClass::WorkGroup => GlobalUse::all(), - crate::StorageClass::PushConstant => GlobalUse::READ, - }; - if !allowed_usage.contains(usage) { - log::warn!("\tUsage error for: {:?}", var); - log::warn!( - "\tAllowed usage: {:?}, requested: {:?}", - allowed_usage, - usage - ); - return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)); - } - - if let Some(ref bind) = var.binding { - while self.bind_group_masks.len() <= bind.group as usize { - self.bind_group_masks.push(BitSet::new()); - } - if !self.bind_group_masks[bind.group as usize].insert(bind.binding as usize) { - return Err(EntryPointError::BindingCollision(var_handle)); - } - } - } - - self.validate_function(&ep.function, info, module)?; - Ok(()) - } - - /// Check the given module to be valid. - pub fn validate(&mut self, module: &crate::Module) -> Result { - self.typifier.clear(); - self.type_flags.clear(); - self.type_flags - .resize(module.types.len(), TypeFlags::empty()); - - let analysis = Analysis::new(module, self.analysis_flags)?; - - for (handle, constant) in module.constants.iter() { - self.validate_constant(handle, &module.constants, &module.types) - .map_err(|error| ValidationError::Constant { - handle, - name: constant.name.clone().unwrap_or_default(), - error, - })?; - } - - // doing after the globals, so that `type_flags` is ready - for (handle, ty) in module.types.iter() { - let ty_flags = self - .validate_type(ty, handle, &module.constants) - .map_err(|error| ValidationError::Type { - handle, - name: ty.name.clone().unwrap_or_default(), - error, - })?; - self.type_flags[handle.index()] = ty_flags; - } - - for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, &module.types) - .map_err(|error| ValidationError::GlobalVariable { - handle: var_handle, - name: var.name.clone().unwrap_or_default(), - error, - })?; - } - - for (handle, fun) in module.functions.iter() { - self.validate_function(fun, &analysis[handle], module) - .map_err(|error| ValidationError::Function { - handle, - name: fun.name.clone().unwrap_or_default(), - error, - })?; - } - - let mut ep_map = FastHashSet::default(); - for (index, ep) in module.entry_points.iter().enumerate() { - if !ep_map.insert((ep.stage, &ep.name)) { - return Err(ValidationError::EntryPoint { - stage: ep.stage, - name: ep.name.clone(), - error: EntryPointError::Conflict, - }); - } - let info = analysis.get_entry_point(index); - self.validate_entry_point(ep, info, module) - .map_err(|error| ValidationError::EntryPoint { - stage: ep.stage, - name: ep.name.clone(), - error, - })?; - } - - Ok(analysis) - } -} diff --git a/src/proc/analyzer.rs b/src/valid/analyzer.rs similarity index 99% rename from src/proc/analyzer.rs rename to src/valid/analyzer.rs index c0b131a9ba..3264dfdf56 100644 --- a/src/proc/analyzer.rs +++ b/src/valid/analyzer.rs @@ -692,13 +692,13 @@ bitflags::bitflags! { #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] -pub struct Analysis { +pub struct ModuleInfo { flags: AnalysisFlags, functions: Vec, entry_points: Vec, } -impl Analysis { +impl ModuleInfo { /// Builds the `FunctionInfo` based on the function, and validates the /// uniform control flow if required by the expressions of this function. fn process_function( @@ -732,9 +732,9 @@ impl Analysis { Ok(info) } - /// Analyze a module and return the `Analysis`, if successful. + /// Analyze a module and return the `ModuleInfo`, if successful. pub fn new(module: &crate::Module, flags: AnalysisFlags) -> Result { - let mut this = Analysis { + let mut this = ModuleInfo { flags, functions: Vec::with_capacity(module.functions.len()), entry_points: Vec::with_capacity(module.entry_points.len()), @@ -761,7 +761,7 @@ impl Analysis { } } -impl ops::Index> for Analysis { +impl ops::Index> for ModuleInfo { type Output = FunctionInfo; fn index(&self, handle: Handle) -> &FunctionInfo { &self.functions[handle.index()] diff --git a/src/valid/expression.rs b/src/valid/expression.rs new file mode 100644 index 0000000000..119d1f4d82 --- /dev/null +++ b/src/valid/expression.rs @@ -0,0 +1,549 @@ +use crate::{ + arena::{Arena, Handle}, + proc::Typifier, +}; + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ExpressionError { + #[error("Doesn't exist")] + DoesntExist, + #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")] + NotInScope, + #[error("Depends on {0:?}, which has not been processed yet")] + ForwardDependency(Handle), + #[error("Base type {0:?} is not compatible with this expression")] + InvalidBaseType(Handle), + #[error("Accessing with index {0:?} can't be done")] + InvalidIndexType(Handle), + #[error("Accessing index {1} is out of {0:?} bounds")] + IndexOutOfBounds(Handle, u32), + #[error("Function argument {0:?} doesn't exist")] + FunctionArgumentDoesntExist(u32), + #[error("Constant {0:?} doesn't exist")] + ConstantDoesntExist(Handle), + #[error("Global variable {0:?} doesn't exist")] + GlobalVarDoesntExist(Handle), + #[error("Local variable {0:?} doesn't exist")] + LocalVarDoesntExist(Handle), + #[error("Loading of {0:?} can't be done")] + InvalidPointerType(Handle), + #[error("Array length of {0:?} can't be done")] + InvalidArrayType(Handle), + #[error("Compose type {0:?} doesn't exist")] + ComposeTypeDoesntExist(Handle), + #[error("Composing of type {0:?} can't be done")] + InvalidComposeType(Handle), + #[error("Composing expects {expected} components but {given} were given")] + InvalidComposeCount { given: u32, expected: u32 }, + #[error("Composing {0}'s component {1:?} is not expected")] + InvalidComponentType(u32, Handle), + #[error("Operation {0:?} can't work with {1:?}")] + InvalidUnaryOperandType(crate::UnaryOperator, Handle), + #[error("Operation {0:?} can't work with {1:?} and {2:?}")] + InvalidBinaryOperandTypes( + crate::BinaryOperator, + Handle, + Handle, + ), + #[error("Selecting is not possible")] + InvalidSelectTypes, + #[error("Relational argument {0:?} is not a boolean vector")] + InvalidBooleanVector(Handle), + #[error("Relational argument {0:?} is not a float")] + InvalidFloatArgument(Handle), +} + +struct ExpressionTypeResolver<'a> { + root: Handle, + types: &'a Arena, + typifier: &'a Typifier, +} + +impl<'a> ExpressionTypeResolver<'a> { + fn resolve( + &self, + handle: Handle, + ) -> Result<&'a crate::TypeInner, ExpressionError> { + if handle < self.root { + Ok(self.typifier.get(handle, self.types)) + } else { + Err(ExpressionError::ForwardDependency(handle)) + } + } +} + +impl super::Validator { + pub(super) fn validate_expression( + &self, + root: Handle, + expression: &crate::Expression, + function: &crate::Function, + module: &crate::Module, + ) -> Result<(), ExpressionError> { + use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti}; + + let resolver = ExpressionTypeResolver { + root, + types: &module.types, + typifier: &self.typifier, + }; + + match *expression { + E::Access { base, index } => { + match *resolver.resolve(base)? { + Ti::Vector { .. } + | Ti::Matrix { .. } + | Ti::Array { .. } + | Ti::Pointer { .. } => {} + ref other => { + log::error!("Indexing of {:?}", other); + return Err(ExpressionError::InvalidBaseType(base)); + } + } + match *resolver.resolve(index)? { + //TODO: only allow one of these + Ti::Scalar { + kind: Sk::Sint, + width: _, + } + | Ti::Scalar { + kind: Sk::Uint, + width: _, + } => {} + ref other => { + log::error!("Indexing by {:?}", other); + return Err(ExpressionError::InvalidIndexType(index)); + } + } + } + E::AccessIndex { base, index } => { + let limit = match *resolver.resolve(base)? { + Ti::Vector { size, .. } => size as u32, + Ti::Matrix { columns, .. } => columns as u32, + Ti::Array { + size: crate::ArraySize::Constant(handle), + .. + } => module.constants[handle].to_array_length().unwrap(), + Ti::Array { .. } => !0, // can't statically know, but need run-time checks + Ti::Pointer { .. } => !0, //TODO + Ti::Struct { + ref members, + block: _, + } => members.len() as u32, + ref other => { + log::error!("Indexing of {:?}", other); + return Err(ExpressionError::InvalidBaseType(base)); + } + }; + if index >= limit { + return Err(ExpressionError::IndexOutOfBounds(base, index)); + } + } + E::Constant(handle) => { + let _ = module + .constants + .try_get(handle) + .ok_or(ExpressionError::ConstantDoesntExist(handle))?; + } + E::Compose { ref components, ty } => { + match module + .types + .try_get(ty) + .ok_or(ExpressionError::ComposeTypeDoesntExist(ty))? + .inner + { + // vectors are composed from scalars or other vectors + Ti::Vector { size, kind, width } => { + let mut total = 0; + for (index, &comp) in components.iter().enumerate() { + total += match *resolver.resolve(comp)? { + Ti::Scalar { + kind: comp_kind, + width: comp_width, + } if comp_kind == kind && comp_width == width => 1, + Ti::Vector { + size: comp_size, + kind: comp_kind, + width: comp_width, + } if comp_kind == kind && comp_width == width => comp_size as u32, + ref other => { + log::error!("Vector component[{}] type {:?}", index, other); + return Err(ExpressionError::InvalidComponentType( + index as u32, + comp, + )); + } + }; + } + if size as u32 != total { + return Err(ExpressionError::InvalidComposeCount { + expected: size as u32, + given: total, + }); + } + } + // matrix are composed from column vectors + Ti::Matrix { + columns, + rows, + width, + } => { + let inner = Ti::Vector { + size: rows, + kind: Sk::Float, + width, + }; + if columns as usize != components.len() { + return Err(ExpressionError::InvalidComposeCount { + expected: columns as u32, + given: components.len() as u32, + }); + } + for (index, &comp) in components.iter().enumerate() { + let tin = resolver.resolve(comp)?; + if tin != &inner { + log::error!("Matrix component[{}] type {:?}", index, tin); + return Err(ExpressionError::InvalidComponentType( + index as u32, + comp, + )); + } + } + } + Ti::Array { + base, + size: crate::ArraySize::Constant(handle), + stride: _, + } => { + let count = module.constants[handle].to_array_length().unwrap(); + if count as usize != components.len() { + return Err(ExpressionError::InvalidComposeCount { + expected: count, + given: components.len() as u32, + }); + } + let base_inner = &module.types[base].inner; + for (index, &comp) in components.iter().enumerate() { + let tin = resolver.resolve(comp)?; + if tin != base_inner { + log::error!("Array component[{}] type {:?}", index, tin); + return Err(ExpressionError::InvalidComponentType( + index as u32, + comp, + )); + } + } + } + Ti::Struct { + block: _, + ref members, + } => { + for (index, (member, &comp)) in members.iter().zip(components).enumerate() { + let tin = resolver.resolve(comp)?; + if tin != &module.types[member.ty].inner { + log::error!("Struct component[{}] type {:?}", index, tin); + return Err(ExpressionError::InvalidComponentType( + index as u32, + comp, + )); + } + } + if members.len() != components.len() { + return Err(ExpressionError::InvalidComposeCount { + given: components.len() as u32, + expected: members.len() as u32, + }); + } + } + ref other => { + log::error!("Composing of {:?}", other); + return Err(ExpressionError::InvalidComposeType(ty)); + } + } + } + E::FunctionArgument(index) => { + if index >= function.arguments.len() as u32 { + return Err(ExpressionError::FunctionArgumentDoesntExist(index)); + } + } + E::GlobalVariable(handle) => { + let _ = module + .global_variables + .try_get(handle) + .ok_or(ExpressionError::GlobalVarDoesntExist(handle))?; + } + E::LocalVariable(handle) => { + let _ = function + .local_variables + .try_get(handle) + .ok_or(ExpressionError::LocalVarDoesntExist(handle))?; + } + E::Load { pointer } => match *resolver.resolve(pointer)? { + Ti::Pointer { .. } | Ti::ValuePointer { .. } => {} + ref other => { + log::error!("Loading {:?}", other); + return Err(ExpressionError::InvalidPointerType(pointer)); + } + }, + #[allow(unused)] + E::ImageSample { + image, + sampler, + coordinate, + array_index, + offset, + level, + depth_ref, + } => {} + #[allow(unused)] + E::ImageLoad { + image, + coordinate, + array_index, + index, + } => {} + #[allow(unused)] + E::ImageQuery { image, query } => {} + E::Unary { op, expr } => { + use crate::UnaryOperator as Uo; + let inner = resolver.resolve(expr)?; + match (op, inner.scalar_kind()) { + (_, Some(Sk::Sint)) + | (_, Some(Sk::Bool)) + | (Uo::Negate, Some(Sk::Float)) + | (Uo::Not, Some(Sk::Uint)) => {} + other => { + log::error!("Op {:?} kind {:?}", op, other); + return Err(ExpressionError::InvalidUnaryOperandType(op, expr)); + } + } + } + E::Binary { op, left, right } => { + use crate::BinaryOperator as Bo; + let left_inner = resolver.resolve(left)?; + let right_inner = resolver.resolve(right)?; + let good = match op { + Bo::Add | Bo::Subtract | Bo::Divide | Bo::Modulo => match *left_inner { + Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool => false, + }, + _ => false, + }, + Bo::Multiply => { + let kind_match = match left_inner.scalar_kind() { + Some(Sk::Uint) | Some(Sk::Sint) | Some(Sk::Float) => true, + Some(Sk::Bool) | None => false, + }; + //TODO: should we be more restrictive here? I.e. expect scalar only to the left. + let types_match = match (left_inner, right_inner) { + (&Ti::Scalar { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. }) + | (&Ti::Vector { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. }) + | (&Ti::Scalar { kind: kind1, .. }, &Ti::Vector { kind: kind2, .. }) => { + kind1 == kind2 + } + ( + &Ti::Scalar { + kind: Sk::Float, .. + }, + &Ti::Matrix { .. }, + ) + | ( + &Ti::Matrix { .. }, + &Ti::Scalar { + kind: Sk::Float, .. + }, + ) => true, + ( + &Ti::Vector { + kind: kind1, + size: size1, + .. + }, + &Ti::Vector { + kind: kind2, + size: size2, + .. + }, + ) => kind1 == kind2 && size1 == size2, + ( + &Ti::Matrix { columns, .. }, + &Ti::Vector { + kind: Sk::Float, + size, + .. + }, + ) => columns == size, + ( + &Ti::Vector { + kind: Sk::Float, + size, + .. + }, + &Ti::Matrix { rows, .. }, + ) => size == rows, + (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => { + columns == rows + } + _ => false, + }; + let left_width = match *left_inner { + Ti::Scalar { width, .. } + | Ti::Vector { width, .. } + | Ti::Matrix { width, .. } => width, + _ => 0, + }; + let right_width = match *right_inner { + Ti::Scalar { width, .. } + | Ti::Vector { width, .. } + | Ti::Matrix { width, .. } => width, + _ => 0, + }; + kind_match && types_match && left_width == right_width + } + Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner, + Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { + match *left_inner { + Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + } + } + Bo::LogicalAnd | Bo::LogicalOr => match *left_inner { + Ti::Scalar { kind: Sk::Bool, .. } | Ti::Vector { kind: Sk::Bool, .. } => { + left_inner == right_inner + } + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::And | Bo::ExclusiveOr | Bo::InclusiveOr => match *left_inner { + Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { + Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Bool | Sk::Float => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::ShiftLeft | Bo::ShiftRight => { + let (base_size, base_kind) = match *left_inner { + Ti::Scalar { kind, .. } => (Ok(None), kind), + Ti::Vector { size, kind, .. } => (Ok(Some(size)), kind), + ref other => { + log::error!("Op {:?} base type {:?}", op, other); + (Err(()), Sk::Bool) + } + }; + let shift_size = match *right_inner { + Ti::Scalar { kind: Sk::Uint, .. } => Ok(None), + Ti::Vector { + size, + kind: Sk::Uint, + .. + } => Ok(Some(size)), + ref other => { + log::error!("Op {:?} shift type {:?}", op, other); + Err(()) + } + }; + match base_kind { + Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, + Sk::Float | Sk::Bool => false, + } + } + }; + if !good { + return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right)); + } + } + E::Select { + condition, + accept, + reject, + } => { + let accept_inner = resolver.resolve(accept)?; + let reject_inner = resolver.resolve(reject)?; + let condition_good = match *resolver.resolve(condition)? { + Ti::Scalar { + kind: Sk::Bool, + width: _, + } => accept_inner.is_sized(), + Ti::Vector { + size, + kind: Sk::Bool, + width: _, + } => match *accept_inner { + Ti::Vector { + size: other_size, .. + } => size == other_size, + _ => false, + }, + _ => false, + }; + if !condition_good || accept_inner != reject_inner { + return Err(ExpressionError::InvalidSelectTypes); + } + } + #[allow(unused)] + E::Derivative { axis, expr } => { + //TODO: check stage + } + E::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + let argument_inner = resolver.resolve(argument)?; + match fun { + Rf::All | Rf::Any => match *argument_inner { + Ti::Vector { kind: Sk::Bool, .. } => {} + ref other => { + log::error!("All/Any of type {:?}", other); + return Err(ExpressionError::InvalidBooleanVector(argument)); + } + }, + Rf::IsNan | Rf::IsInf | Rf::IsFinite | Rf::IsNormal => match *argument_inner { + Ti::Scalar { + kind: Sk::Float, .. + } + | Ti::Vector { + kind: Sk::Float, .. + } => {} + ref other => { + log::error!("Float test of type {:?}", other); + return Err(ExpressionError::InvalidFloatArgument(argument)); + } + }, + } + } + #[allow(unused)] + E::Math { + fun, + arg, + arg1, + arg2, + } => {} + #[allow(unused)] + E::As { + expr, + kind, + convert, + } => {} + #[allow(unused)] + E::Call(function) => {} + E::ArrayLength(expr) => match *resolver.resolve(expr)? { + Ti::Array { .. } => {} + ref other => { + log::error!("Array length of {:?}", other); + return Err(ExpressionError::InvalidArrayType(expr)); + } + }, + } + Ok(()) + } +} diff --git a/src/valid/function.rs b/src/valid/function.rs new file mode 100644 index 0000000000..2edf339c70 --- /dev/null +++ b/src/valid/function.rs @@ -0,0 +1,510 @@ +use super::{analyzer::FunctionInfo, ExpressionError, TypeFlags}; +use crate::{ + arena::{Arena, Handle}, + proc::{ResolveContext, TypifyError}, +}; + +#[derive(Clone, Debug, thiserror::Error)] +pub enum CallError { + #[error("Bad function")] + InvalidFunction, + #[error("The callee is declared after the caller")] + ForwardDeclaredFunction, + #[error("Argument {index} expression is invalid")] + Argument { + index: usize, + #[source] + error: ExpressionError, + }, + #[error("Result expression {0:?} has already been introduced earlier")] + ResultAlreadyInScope(Handle), + #[error("Result value is invalid")] + ResultValue(#[source] ExpressionError), + #[error("Requires {required} arguments, but {seen} are provided")] + ArgumentCount { required: usize, seen: usize }, + #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")] + ArgumentType { + index: usize, + required: Handle, + seen_expression: Handle, + }, + #[error("Result value {seen_expression:?} does not match the type {required:?}")] + ResultType { + required: Option>, + seen_expression: Option>, + }, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum LocalVariableError { + #[error("Initializer doesn't match the variable type")] + InitializerType, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum FunctionError { + #[error(transparent)] + Resolve(#[from] TypifyError), + #[error("Expression {handle:?} is invalid")] + Expression { + handle: Handle, + #[source] + error: ExpressionError, + }, + #[error("Expression {0:?} can't be introduced - it's already in scope")] + ExpressionAlreadyInScope(Handle), + #[error("Local variable {handle:?} '{name}' is invalid")] + LocalVariable { + handle: Handle, + name: String, + #[source] + error: LocalVariableError, + }, + #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")] + InvalidArgumentType { index: usize, name: String }, + #[error("There are instructions after `return`/`break`/`continue`")] + InstructionsAfterReturn, + #[error("The `break`/`continue` is used outside of a loop context")] + BreakContinueOutsideOfLoop, + #[error("The `return` is called within a `continuing` block")] + InvalidReturnSpot, + #[error("The `return` value {0:?} does not match the function return value")] + InvalidReturnType(Option>), + #[error("The `if` condition {0:?} is not a boolean scalar")] + InvalidIfType(Handle), + #[error("The `switch` value {0:?} is not an integer scalar")] + InvalidSwitchType(Handle), + #[error("Multiple `switch` cases for {0} are present")] + ConflictingSwitchCase(i32), + #[error("The pointer {0:?} doesn't relate to a valid destination for a store")] + InvalidStorePointer(Handle), + #[error("The value {0:?} can not be stored")] + InvalidStoreValue(Handle), + #[error("Store of {value:?} into {pointer:?} doesn't have matching types")] + InvalidStoreTypes { + pointer: Handle, + value: Handle, + }, + #[error("The image array can't be indexed by {0:?}")] + InvalidArrayIndex(Handle), + #[error("The expression {0:?} is currupted")] + InvalidExpression(Handle), + #[error("The expression {0:?} is not an image")] + InvalidImage(Handle), + #[error("Call to {function:?} is invalid")] + InvalidCall { + function: Handle, + #[source] + error: CallError, + }, +} + +bitflags::bitflags! { + #[repr(transparent)] + struct Flags: u8 { + /// The control can jump out of this block. + const CAN_JUMP = 0x1; + /// The control is in a loop, can break and continue. + const IN_LOOP = 0x2; + } +} + +struct BlockContext<'a> { + flags: Flags, + expressions: &'a Arena, + types: &'a Arena, + functions: &'a Arena, + return_type: Option>, +} + +impl<'a> BlockContext<'a> { + pub(super) fn new(fun: &'a crate::Function, module: &'a crate::Module) -> Self { + Self { + flags: Flags::CAN_JUMP, + expressions: &fun.expressions, + types: &module.types, + functions: &module.functions, + return_type: fun.result.as_ref().map(|fr| fr.ty), + } + } + + fn with_flags(&self, flags: Flags) -> Self { + BlockContext { + flags, + expressions: self.expressions, + types: self.types, + functions: self.functions, + return_type: self.return_type, + } + } + + fn get_expression( + &self, + handle: Handle, + ) -> Result<&'a crate::Expression, FunctionError> { + self.expressions + .try_get(handle) + .ok_or(FunctionError::InvalidExpression(handle)) + } +} + +impl super::Validator { + fn validate_call( + &mut self, + function: Handle, + arguments: &[Handle], + result: Option>, + context: &BlockContext, + ) -> Result<(), CallError> { + let fun = context + .functions + .try_get(function) + .ok_or(CallError::InvalidFunction)?; + if fun.arguments.len() != arguments.len() { + return Err(CallError::ArgumentCount { + required: fun.arguments.len(), + seen: arguments.len(), + }); + } + for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() { + let ty = self + .resolve_statement_type_impl(expr, context.types) + .map_err(|error| CallError::Argument { index, error })?; + if ty != &context.types[arg.ty].inner { + return Err(CallError::ArgumentType { + index, + required: arg.ty, + seen_expression: expr, + }); + } + } + + if let Some(expr) = result { + if self.valid_expression_set.insert(expr.index()) { + self.valid_expression_list.push(expr); + } else { + return Err(CallError::ResultAlreadyInScope(expr)); + } + } + + let result_ty = result + .map(|expr| self.resolve_statement_type_impl(expr, context.types)) + .transpose() + .map_err(CallError::ResultValue)?; + let expected_ty = fun.result.as_ref().map(|fr| &context.types[fr.ty].inner); + if result_ty != expected_ty { + log::error!( + "Called function returns {:?} where {:?} is expected", + result_ty, + expected_ty + ); + return Err(CallError::ResultType { + required: fun.result.as_ref().map(|fr| fr.ty), + seen_expression: result, + }); + } + Ok(()) + } + + fn resolve_statement_type_impl<'a>( + &'a self, + handle: Handle, + types: &'a Arena, + ) -> Result<&'a crate::TypeInner, ExpressionError> { + if !self.valid_expression_set.contains(handle.index()) { + return Err(ExpressionError::NotInScope); + } + self.typifier + .try_get(handle, types) + .ok_or(ExpressionError::DoesntExist) + } + + fn resolve_statement_type<'a>( + &'a self, + handle: Handle, + types: &'a Arena, + ) -> Result<&'a crate::TypeInner, FunctionError> { + self.resolve_statement_type_impl(handle, types) + .map_err(|error| FunctionError::Expression { handle, error }) + } + + fn validate_block_impl( + &mut self, + statements: &[crate::Statement], + context: &BlockContext, + ) -> Result<(), FunctionError> { + use crate::{Statement as S, TypeInner as Ti}; + let mut finished = false; + for statement in statements { + if finished { + return Err(FunctionError::InstructionsAfterReturn); + } + match *statement { + S::Emit(ref range) => { + for handle in range.clone() { + if self.valid_expression_set.insert(handle.index()) { + self.valid_expression_list.push(handle); + } else { + return Err(FunctionError::ExpressionAlreadyInScope(handle)); + } + } + } + S::Block(ref block) => self.validate_block(block, context)?, + S::If { + condition, + ref accept, + ref reject, + } => { + match *self.resolve_statement_type(condition, context.types)? { + Ti::Scalar { + kind: crate::ScalarKind::Bool, + width: _, + } => {} + _ => return Err(FunctionError::InvalidIfType(condition)), + } + self.validate_block(accept, context)?; + self.validate_block(reject, context)?; + } + S::Switch { + selector, + ref cases, + ref default, + } => { + match *self.resolve_statement_type(selector, context.types)? { + Ti::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => {} + _ => return Err(FunctionError::InvalidSwitchType(selector)), + } + self.select_cases.clear(); + for case in cases { + if !self.select_cases.insert(case.value) { + return Err(FunctionError::ConflictingSwitchCase(case.value)); + } + } + for case in cases { + self.validate_block(&case.body, context)?; + } + self.validate_block(default, context)?; + } + S::Loop { + ref body, + ref continuing, + } => { + // special handling for block scoping is needed here, + // because the continuing{} block inherits the scope + let base_expression_count = self.valid_expression_list.len(); + self.validate_block_impl( + body, + &context.with_flags(Flags::CAN_JUMP | Flags::IN_LOOP), + )?; + self.validate_block_impl(continuing, &context.with_flags(Flags::empty()))?; + for handle in self.valid_expression_list.drain(base_expression_count..) { + self.valid_expression_set.remove(handle.index()); + } + } + S::Break | S::Continue => { + if !context.flags.contains(Flags::IN_LOOP) { + return Err(FunctionError::BreakContinueOutsideOfLoop); + } + finished = true; + } + S::Return { value } => { + if !context.flags.contains(Flags::CAN_JUMP) { + return Err(FunctionError::InvalidReturnSpot); + } + let value_ty = value + .map(|expr| self.resolve_statement_type(expr, context.types)) + .transpose()?; + let expected_ty = context.return_type.map(|ty| &context.types[ty].inner); + if value_ty != expected_ty { + log::error!( + "Returning {:?} where {:?} is expected", + value_ty, + expected_ty + ); + return Err(FunctionError::InvalidReturnType(value)); + } + finished = true; + } + S::Kill => { + finished = true; + } + S::Store { pointer, value } => { + let mut current = pointer; + loop { + self.typifier.try_get(current, context.types).ok_or( + FunctionError::Expression { + handle: current, + error: ExpressionError::DoesntExist, + }, + )?; + match context.expressions[current] { + crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. } => current = base, + crate::Expression::LocalVariable(_) + | crate::Expression::GlobalVariable(_) + | crate::Expression::FunctionArgument(_) => break, + _ => return Err(FunctionError::InvalidStorePointer(current)), + } + } + + let value_ty = self.resolve_statement_type(value, context.types)?; + match *value_ty { + Ti::Image { .. } | Ti::Sampler { .. } => { + return Err(FunctionError::InvalidStoreValue(value)); + } + _ => {} + } + let good = match self.typifier.try_get(pointer, context.types) { + Some(&Ti::Pointer { base, class: _ }) => { + *value_ty == context.types[base].inner + } + Some(&Ti::ValuePointer { + size: Some(size), + kind, + width, + class: _, + }) => *value_ty == Ti::Vector { size, kind, width }, + Some(&Ti::ValuePointer { + size: None, + kind, + width, + class: _, + }) => *value_ty == Ti::Scalar { kind, width }, + _ => false, + }; + if !good { + return Err(FunctionError::InvalidStoreTypes { pointer, value }); + } + } + S::ImageStore { + image, + coordinate: _, + array_index, + value, + } => { + let _expected_coordinate_ty = match *context.get_expression(image)? { + crate::Expression::GlobalVariable(_var_handle) => (), //TODO + _ => return Err(FunctionError::InvalidImage(image)), + }; + let value_ty = self.typifier.get(value, context.types); + match *value_ty { + Ti::Scalar { .. } | Ti::Vector { .. } => {} + _ => { + return Err(FunctionError::InvalidStoreValue(value)); + } + } + if let Some(expr) = array_index { + match *self.typifier.get(expr, context.types) { + Ti::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => (), + _ => return Err(FunctionError::InvalidArrayIndex(expr)), + } + } + } + S::Call { + function, + ref arguments, + result, + } => { + if let Err(error) = self.validate_call(function, arguments, result, context) { + return Err(FunctionError::InvalidCall { function, error }); + } + } + } + } + Ok(()) + } + + fn validate_block( + &mut self, + statements: &[crate::Statement], + context: &BlockContext, + ) -> Result<(), FunctionError> { + let base_expression_count = self.valid_expression_list.len(); + self.validate_block_impl(statements, context)?; + for handle in self.valid_expression_list.drain(base_expression_count..) { + self.valid_expression_set.remove(handle.index()); + } + Ok(()) + } + + fn validate_local_var( + &self, + var: &crate::LocalVariable, + types: &Arena, + constants: &Arena, + ) -> Result<(), LocalVariableError> { + log::debug!("var {:?}", var); + if let Some(const_handle) = var.init { + match constants[const_handle].inner { + crate::ConstantInner::Scalar { width, ref value } => { + let ty_inner = crate::TypeInner::Scalar { + width, + kind: value.scalar_kind(), + }; + if types[var.ty].inner != ty_inner { + return Err(LocalVariableError::InitializerType); + } + } + crate::ConstantInner::Composite { ty, components: _ } => { + if ty != var.ty { + return Err(LocalVariableError::InitializerType); + } + } + } + } + Ok(()) + } + + pub(super) fn validate_function( + &mut self, + fun: &crate::Function, + _info: &FunctionInfo, + module: &crate::Module, + ) -> Result<(), FunctionError> { + let resolve_ctx = ResolveContext { + constants: &module.constants, + global_vars: &module.global_variables, + local_vars: &fun.local_variables, + functions: &module.functions, + arguments: &fun.arguments, + }; + self.typifier + .resolve_all(&fun.expressions, &module.types, &resolve_ctx)?; + + for (var_handle, var) in fun.local_variables.iter() { + self.validate_local_var(var, &module.types, &module.constants) + .map_err(|error| FunctionError::LocalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + error, + })?; + } + + for (index, argument) in fun.arguments.iter().enumerate() { + if !self.type_flags[argument.ty.index()].contains(TypeFlags::DATA) { + return Err(FunctionError::InvalidArgumentType { + index, + name: argument.name.clone().unwrap_or_default(), + }); + } + } + + self.valid_expression_set.clear(); + for (handle, expr) in fun.expressions.iter() { + if expr.needs_pre_emit() { + self.valid_expression_set.insert(handle.index()); + } + if let Err(error) = self.validate_expression(handle, expr, fun, module) { + return Err(FunctionError::Expression { handle, error }); + } + } + + self.validate_block(&fun.body, &BlockContext::new(fun, module)) + } +} diff --git a/src/valid/interface.rs b/src/valid/interface.rs new file mode 100644 index 0000000000..39d5ebdef8 --- /dev/null +++ b/src/valid/interface.rs @@ -0,0 +1,428 @@ +use super::{ + analyzer::{FunctionInfo, GlobalUse}, + FunctionError, TypeFlags, +}; +use crate::arena::{Arena, Handle}; + +use bit_set::BitSet; + +const MAX_WORKGROUP_SIZE: u32 = 0x4000; + +#[derive(Clone, Debug, thiserror::Error)] +pub enum GlobalVariableError { + #[error("Usage isn't compatible with the storage class")] + InvalidUsage, + #[error("Type isn't compatible with the storage class")] + InvalidType, + #[error("Storage access {seen:?} exceeds the allowed {allowed:?}")] + InvalidStorageAccess { + allowed: crate::StorageAccess, + seen: crate::StorageAccess, + }, + #[error("Type flags {seen:?} do not meet the required {required:?}")] + MissingTypeFlags { + required: TypeFlags, + seen: TypeFlags, + }, + #[error("Binding decoration is missing or not applicable")] + InvalidBinding, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum VaryingError { + #[error("The type {0:?} does not match the varying")] + InvalidType(Handle), + #[error("Interpolation is not valid")] + InvalidInterpolation, + #[error("BuiltIn {0:?} is not available at this stage")] + InvalidBuiltInStage(crate::BuiltIn), + #[error("BuiltIn type for {0:?} is invalid")] + InvalidBuiltInType(crate::BuiltIn), + #[error("Struct member {0} is missing a binding")] + MemberMissingBinding(u32), + #[error("Multiple bindings at location {location} are present")] + BindingCollision { location: u32 }, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum EntryPointError { + #[error("Multiple conflicting entry points")] + Conflict, + #[error("Early depth test is not applicable")] + UnexpectedEarlyDepthTest, + #[error("Workgroup size is not applicable")] + UnexpectedWorkgroupSize, + #[error("Workgroup size is out of range")] + OutOfRangeWorkgroupSize, + #[error("Global variable {0:?} is used incorrectly as {1:?}")] + InvalidGlobalUsage(Handle, GlobalUse), + #[error("Bindings for {0:?} conflict with other resource")] + BindingCollision(Handle), + #[error("Argument {0} varying error")] + Argument(u32, #[source] VaryingError), + #[error("Result varying error")] + Result(#[source] VaryingError), + #[error("Location {location} onterpolation of an integer has to be flat")] + InvalidIntegerInterpolation { location: u32 }, + #[error(transparent)] + Function(#[from] FunctionError), +} + +fn storage_usage(access: crate::StorageAccess) -> GlobalUse { + let mut storage_usage = GlobalUse::QUERY; + if access.contains(crate::StorageAccess::LOAD) { + storage_usage |= GlobalUse::READ; + } + if access.contains(crate::StorageAccess::STORE) { + storage_usage |= GlobalUse::WRITE; + } + storage_usage +} + +struct VaryingContext<'a> { + ty: Handle, + stage: crate::ShaderStage, + output: bool, + types: &'a Arena, + location_mask: &'a mut BitSet, +} + +impl VaryingContext<'_> { + fn validate_impl(&mut self, binding: &crate::Binding) -> Result<(), VaryingError> { + use crate::{ + BuiltIn as Bi, ScalarKind as Sk, ShaderStage as St, TypeInner as Ti, VectorSize as Vs, + }; + + let ty_inner = &self.types[self.ty].inner; + match *binding { + crate::Binding::BuiltIn(built_in) => { + let width = 4; + let (visible, type_good) = match built_in { + Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => ( + self.stage == St::Vertex && !self.output, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), + Bi::ClipDistance => ( + self.stage == St::Vertex && self.output, + match *ty_inner { + Ti::Array { base, .. } => { + self.types[base].inner + == Ti::Scalar { + kind: Sk::Float, + width, + } + } + _ => false, + }, + ), + Bi::PointSize => ( + self.stage == St::Vertex && self.output, + *ty_inner + == Ti::Scalar { + kind: Sk::Float, + width, + }, + ), + Bi::Position => ( + match self.stage { + St::Vertex => self.output, + St::Fragment => !self.output, + St::Compute => false, + }, + *ty_inner + == Ti::Vector { + size: Vs::Quad, + kind: Sk::Float, + width, + }, + ), + Bi::FragDepth => ( + self.stage == St::Fragment && self.output, + *ty_inner + == Ti::Scalar { + kind: Sk::Float, + width, + }, + ), + Bi::FrontFacing => ( + self.stage == St::Fragment && !self.output, + *ty_inner + == Ti::Scalar { + kind: Sk::Bool, + width: crate::BOOL_WIDTH, + }, + ), + Bi::SampleIndex => ( + self.stage == St::Fragment && !self.output, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), + Bi::SampleMask => ( + self.stage == St::Fragment, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), + Bi::LocalInvocationIndex => ( + self.stage == St::Compute && !self.output, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), + Bi::GlobalInvocationId + | Bi::LocalInvocationId + | Bi::WorkGroupId + | Bi::WorkGroupSize => ( + self.stage == St::Compute && !self.output, + *ty_inner + == Ti::Vector { + size: Vs::Tri, + kind: Sk::Uint, + width, + }, + ), + }; + + if !visible { + return Err(VaryingError::InvalidBuiltInStage(built_in)); + } + if !type_good { + log::warn!("Wrong builtin type: {:?}", ty_inner); + return Err(VaryingError::InvalidBuiltInType(built_in)); + } + } + crate::Binding::Location(location, interpolation) => { + if !self.location_mask.insert(location as usize) { + return Err(VaryingError::BindingCollision { location }); + } + let needs_interpolation = + self.stage == crate::ShaderStage::Fragment && !self.output; + if !needs_interpolation && interpolation.is_some() { + return Err(VaryingError::InvalidInterpolation); + } + match ty_inner.scalar_kind() { + Some(crate::ScalarKind::Float) => {} + Some(_) + if needs_interpolation + && interpolation != Some(crate::Interpolation::Flat) => + { + return Err(VaryingError::InvalidInterpolation); + } + Some(_) => {} + None => return Err(VaryingError::InvalidType(self.ty)), + } + } + } + + Ok(()) + } + + fn validate(mut self, binding: Option<&crate::Binding>) -> Result<(), VaryingError> { + match binding { + Some(binding) => self.validate_impl(binding), + None => { + match self.types[self.ty].inner { + //TODO: check the member types + crate::TypeInner::Struct { + block: false, + ref members, + } => { + for (index, member) in members.iter().enumerate() { + self.ty = member.ty; + match member.binding { + None => { + return Err(VaryingError::MemberMissingBinding(index as u32)) + } + Some(ref binding) => self.validate_impl(binding)?, + } + } + } + _ => return Err(VaryingError::InvalidType(self.ty)), + } + Ok(()) + } + } + } +} + +impl super::Validator { + pub(super) fn validate_global_var( + &self, + var: &crate::GlobalVariable, + types: &Arena, + ) -> Result<(), GlobalVariableError> { + log::debug!("var {:?}", var); + let (allowed_storage_access, required_type_flags, is_resource) = match var.class { + crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage), + crate::StorageClass::Storage => { + match types[var.ty].inner { + crate::TypeInner::Struct { .. } => (), + _ => return Err(GlobalVariableError::InvalidType), + } + ( + crate::StorageAccess::all(), + TypeFlags::DATA | TypeFlags::HOST_SHARED, + true, + ) + } + crate::StorageClass::Uniform => { + match types[var.ty].inner { + crate::TypeInner::Struct { .. } => (), + _ => return Err(GlobalVariableError::InvalidType), + } + ( + crate::StorageAccess::empty(), + TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHARED, + true, + ) + } + crate::StorageClass::Handle => { + let access = match types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => crate::StorageAccess::all(), + crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => { + crate::StorageAccess::empty() + } + _ => return Err(GlobalVariableError::InvalidType), + }; + (access, TypeFlags::empty(), true) + } + crate::StorageClass::Private | crate::StorageClass::WorkGroup => { + (crate::StorageAccess::empty(), TypeFlags::DATA, false) + } + crate::StorageClass::PushConstant => ( + crate::StorageAccess::LOAD, + TypeFlags::DATA | TypeFlags::HOST_SHARED, + false, + ), + }; + + if !allowed_storage_access.contains(var.storage_access) { + return Err(GlobalVariableError::InvalidStorageAccess { + seen: var.storage_access, + allowed: allowed_storage_access, + }); + } + + let type_flags = self.type_flags[var.ty.index()]; + if !type_flags.contains(required_type_flags) { + return Err(GlobalVariableError::MissingTypeFlags { + seen: type_flags, + required: required_type_flags, + }); + } + + if is_resource != var.binding.is_some() { + return Err(GlobalVariableError::InvalidBinding); + } + + Ok(()) + } + + pub(super) fn validate_entry_point( + &mut self, + ep: &crate::EntryPoint, + info: &FunctionInfo, + module: &crate::Module, + ) -> Result<(), EntryPointError> { + if ep.early_depth_test.is_some() && ep.stage != crate::ShaderStage::Fragment { + return Err(EntryPointError::UnexpectedEarlyDepthTest); + } + if ep.stage == crate::ShaderStage::Compute { + if ep + .workgroup_size + .iter() + .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE) + { + return Err(EntryPointError::OutOfRangeWorkgroupSize); + } + } else if ep.workgroup_size != [0; 3] { + return Err(EntryPointError::UnexpectedWorkgroupSize); + } + + self.location_mask.clear(); + for (index, fa) in ep.function.arguments.iter().enumerate() { + let ctx = VaryingContext { + ty: fa.ty, + stage: ep.stage, + output: false, + types: &module.types, + location_mask: &mut self.location_mask, + }; + ctx.validate(fa.binding.as_ref()) + .map_err(|e| EntryPointError::Argument(index as u32, e))?; + } + + self.location_mask.clear(); + if let Some(ref fr) = ep.function.result { + let ctx = VaryingContext { + ty: fr.ty, + stage: ep.stage, + output: true, + types: &module.types, + location_mask: &mut self.location_mask, + }; + ctx.validate(fr.binding.as_ref()) + .map_err(EntryPointError::Result)?; + } + + for bg in self.bind_group_masks.iter_mut() { + bg.clear(); + } + for (var_handle, var) in module.global_variables.iter() { + let usage = info[var_handle]; + if usage.is_empty() { + continue; + } + + let allowed_usage = match var.class { + crate::StorageClass::Function => unreachable!(), + crate::StorageClass::Uniform => GlobalUse::READ | GlobalUse::QUERY, + crate::StorageClass::Storage => storage_usage(var.storage_access), + crate::StorageClass::Handle => match module.types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => storage_usage(var.storage_access), + _ => GlobalUse::READ | GlobalUse::QUERY, + }, + crate::StorageClass::Private | crate::StorageClass::WorkGroup => GlobalUse::all(), + crate::StorageClass::PushConstant => GlobalUse::READ, + }; + if !allowed_usage.contains(usage) { + log::warn!("\tUsage error for: {:?}", var); + log::warn!( + "\tAllowed usage: {:?}, requested: {:?}", + allowed_usage, + usage + ); + return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)); + } + + if let Some(ref bind) = var.binding { + while self.bind_group_masks.len() <= bind.group as usize { + self.bind_group_masks.push(BitSet::new()); + } + if !self.bind_group_masks[bind.group as usize].insert(bind.binding as usize) { + return Err(EntryPointError::BindingCollision(var_handle)); + } + } + } + + self.validate_function(&ep.function, info, module)?; + Ok(()) + } +} diff --git a/src/valid/mod.rs b/src/valid/mod.rs new file mode 100644 index 0000000000..03ece2a5a5 --- /dev/null +++ b/src/valid/mod.rs @@ -0,0 +1,382 @@ +mod analyzer; +mod expression; +mod function; +mod interface; + +use crate::{ + arena::{Arena, Handle}, + proc::Typifier, + FastHashSet, +}; +use bit_set::BitSet; +use thiserror::Error; + +//TODO: analyze the model at the same time as we validate it, +// merge the corresponding matches over expressions and statements. +pub use analyzer::{ + AnalysisError, AnalysisFlags, ExpressionInfo, FunctionInfo, GlobalUse, ModuleInfo, Uniformity, + UniformityRequirements, +}; +pub use expression::ExpressionError; +pub use function::{CallError, FunctionError, LocalVariableError}; +pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; + +bitflags::bitflags! { + #[repr(transparent)] + pub struct TypeFlags: u8 { + /// Can be used for data variables. + const DATA = 0x1; + /// The data type has known size. + const SIZED = 0x2; + /// Can be be used for interfacing between pipeline stages. + const INTERFACE = 0x4; + /// Can be used for host-shareable structures. + const HOST_SHARED = 0x8; + } +} + +#[derive(Debug)] +pub struct Validator { + analysis_flags: AnalysisFlags, + //Note: this is a bit tricky: some of the front-ends as well as backends + // already have to use the typifier, so the work here is redundant in a way. + typifier: Typifier, + type_flags: Vec, + location_mask: BitSet, + bind_group_masks: Vec, + select_cases: FastHashSet, + valid_expression_list: Vec>, + valid_expression_set: BitSet, +} + +#[derive(Clone, Debug, Error)] +pub enum TypeError { + #[error("The {0:?} scalar width {1} is not supported")] + InvalidWidth(crate::ScalarKind, crate::Bytes), + #[error("The base handle {0:?} can not be resolved")] + UnresolvedBase(Handle), + #[error("Expected data type, found {0:?}")] + InvalidData(Handle), + #[error("Structure type {0:?} can not be a block structure")] + InvalidBlockType(Handle), + #[error("Base type {0:?} for the array is invalid")] + InvalidArrayBaseType(Handle), + #[error("The constant {0:?} can not be used for an array size")] + InvalidArraySizeConstant(Handle), + #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")] + InvalidDynamicArray(String, Handle), +} + +#[derive(Clone, Debug, Error)] +pub enum ConstantError { + #[error("The type doesn't match the constant")] + InvalidType, + #[error("The component handle {0:?} can not be resolved")] + UnresolvedComponent(Handle), + #[error("The array size handle {0:?} can not be resolved")] + UnresolvedSize(Handle), +} + +#[derive(Clone, Debug, Error)] +pub enum ValidationError { + #[error("Type {handle:?} '{name}' is invalid")] + Type { + handle: Handle, + name: String, + #[source] + error: TypeError, + }, + #[error("Constant {handle:?} '{name}' is invalid")] + Constant { + handle: Handle, + name: String, + #[source] + error: ConstantError, + }, + #[error("Global variable {handle:?} '{name}' is invalid")] + GlobalVariable { + handle: Handle, + name: String, + #[source] + error: GlobalVariableError, + }, + #[error("Function {handle:?} '{name}' is invalid")] + Function { + handle: Handle, + name: String, + #[source] + error: FunctionError, + }, + #[error("Entry point {name} at {stage:?} is invalid")] + EntryPoint { + stage: crate::ShaderStage, + name: String, + #[source] + error: EntryPointError, + }, + #[error(transparent)] + Analysis(#[from] AnalysisError), + #[error("Module is corrupted")] + Corrupted, +} + +impl crate::TypeInner { + fn is_sized(&self) -> bool { + match *self { + Self::Scalar { .. } + | Self::Vector { .. } + | Self::Matrix { .. } + | Self::Array { + size: crate::ArraySize::Constant(_), + .. + } + | Self::Pointer { .. } + | Self::ValuePointer { .. } + | Self::Struct { .. } => true, + Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } => false, + } + } +} + +impl Validator { + /// Construct a new validator instance. + pub fn new(analysis_flags: AnalysisFlags) -> Self { + Validator { + analysis_flags, + typifier: Typifier::new(), + type_flags: Vec::new(), + location_mask: BitSet::new(), + bind_group_masks: Vec::new(), + select_cases: FastHashSet::default(), + valid_expression_list: Vec::new(), + valid_expression_set: BitSet::new(), + } + } + + fn check_width(kind: crate::ScalarKind, width: crate::Bytes) -> bool { + match kind { + crate::ScalarKind::Bool => width == crate::BOOL_WIDTH, + _ => width == 4, + } + } + + fn validate_type( + &self, + ty: &crate::Type, + handle: Handle, + constants: &Arena, + ) -> Result { + use crate::TypeInner as Ti; + Ok(match ty.inner { + Ti::Scalar { kind, width } | Ti::Vector { kind, width, .. } => { + if !Self::check_width(kind, width) { + return Err(TypeError::InvalidWidth(kind, width)); + } + TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::INTERFACE | TypeFlags::HOST_SHARED + } + Ti::Matrix { width, .. } => { + if !Self::check_width(crate::ScalarKind::Float, width) { + return Err(TypeError::InvalidWidth(crate::ScalarKind::Float, width)); + } + TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::INTERFACE | TypeFlags::HOST_SHARED + } + Ti::Pointer { base, class: _ } => { + if base >= handle { + return Err(TypeError::UnresolvedBase(base)); + } + TypeFlags::DATA | TypeFlags::SIZED + } + Ti::ValuePointer { + size: _, + kind, + width, + class: _, + } => { + if !Self::check_width(kind, width) { + return Err(TypeError::InvalidWidth(kind, width)); + } + TypeFlags::SIZED //TODO: `DATA`? + } + Ti::Array { + base, + size, + stride: _, + } => { + if base >= handle { + return Err(TypeError::UnresolvedBase(base)); + } + let base_flags = self.type_flags[base.index()]; + if !base_flags.contains(TypeFlags::DATA | TypeFlags::SIZED) { + return Err(TypeError::InvalidArrayBaseType(base)); + } + + let sized_flag = match size { + crate::ArraySize::Constant(const_handle) => { + match constants.try_get(const_handle) { + Some(&crate::Constant { + inner: + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Uint(_), + }, + .. + }) => {} + // Accept a signed integer size to avoid + // requiring an explicit uint + // literal. Type inference should make + // this unnecessary. + Some(&crate::Constant { + inner: + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Sint(_), + }, + .. + }) => {} + other => { + log::warn!("Array size {:?}", other); + return Err(TypeError::InvalidArraySizeConstant(const_handle)); + } + } + TypeFlags::SIZED + } + crate::ArraySize::Dynamic => TypeFlags::empty(), + }; + let base_mask = TypeFlags::HOST_SHARED | TypeFlags::INTERFACE; + TypeFlags::DATA | (base_flags & base_mask) | sized_flag + } + Ti::Struct { block, ref members } => { + let mut flags = TypeFlags::all(); + for (i, member) in members.iter().enumerate() { + if member.ty >= handle { + return Err(TypeError::UnresolvedBase(member.ty)); + } + let base_flags = self.type_flags[member.ty.index()]; + flags &= base_flags; + if !base_flags.contains(TypeFlags::DATA) { + return Err(TypeError::InvalidData(member.ty)); + } + if block && !base_flags.contains(TypeFlags::INTERFACE) { + return Err(TypeError::InvalidBlockType(member.ty)); + } + // only the last field can be unsized + if i + 1 != members.len() && !base_flags.contains(TypeFlags::SIZED) { + let name = member.name.clone().unwrap_or_default(); + return Err(TypeError::InvalidDynamicArray(name, member.ty)); + } + } + //TODO: check the spans + flags + } + Ti::Image { .. } | Ti::Sampler { .. } => TypeFlags::empty(), + }) + } + + fn validate_constant( + &self, + handle: Handle, + constants: &Arena, + types: &Arena, + ) -> Result<(), ConstantError> { + let con = &constants[handle]; + match con.inner { + crate::ConstantInner::Scalar { width, ref value } => { + if !Self::check_width(value.scalar_kind(), width) { + return Err(ConstantError::InvalidType); + } + } + crate::ConstantInner::Composite { ty, ref components } => { + match types[ty].inner { + crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } => { + return Err(ConstantError::InvalidType); + } + crate::TypeInner::Array { + size: crate::ArraySize::Constant(size_handle), + .. + } => { + if handle <= size_handle { + return Err(ConstantError::UnresolvedSize(size_handle)); + } + } + _ => {} //TODO + } + if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) { + return Err(ConstantError::UnresolvedComponent(comp)); + } + } + } + Ok(()) + } + + /// Check the given module to be valid. + pub fn validate(&mut self, module: &crate::Module) -> Result { + self.typifier.clear(); + self.type_flags.clear(); + self.type_flags + .resize(module.types.len(), TypeFlags::empty()); + + let analysis = ModuleInfo::new(module, self.analysis_flags)?; + + for (handle, constant) in module.constants.iter() { + self.validate_constant(handle, &module.constants, &module.types) + .map_err(|error| ValidationError::Constant { + handle, + name: constant.name.clone().unwrap_or_default(), + error, + })?; + } + + // doing after the globals, so that `type_flags` is ready + for (handle, ty) in module.types.iter() { + let ty_flags = self + .validate_type(ty, handle, &module.constants) + .map_err(|error| ValidationError::Type { + handle, + name: ty.name.clone().unwrap_or_default(), + error, + })?; + self.type_flags[handle.index()] = ty_flags; + } + + for (var_handle, var) in module.global_variables.iter() { + self.validate_global_var(var, &module.types) + .map_err(|error| ValidationError::GlobalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + error, + })?; + } + + for (handle, fun) in module.functions.iter() { + self.validate_function(fun, &analysis[handle], module) + .map_err(|error| ValidationError::Function { + handle, + name: fun.name.clone().unwrap_or_default(), + error, + })?; + } + + let mut ep_map = FastHashSet::default(); + for (index, ep) in module.entry_points.iter().enumerate() { + if !ep_map.insert((ep.stage, &ep.name)) { + return Err(ValidationError::EntryPoint { + stage: ep.stage, + name: ep.name.clone(), + error: EntryPointError::Conflict, + }); + } + let info = analysis.get_entry_point(index); + self.validate_entry_point(ep, info, module) + .map_err(|error| ValidationError::EntryPoint { + stage: ep.stage, + name: ep.name.clone(), + error, + })?; + } + + Ok(analysis) + } +} diff --git a/tests/parse.rs b/tests/parse.rs index 95c7fc85a7..93faa1526d 100644 --- a/tests/parse.rs +++ b/tests/parse.rs @@ -23,13 +23,13 @@ fn _check_glsl(name: &str) { defines: Default::default(), }, ) { - Ok(m) => match naga::proc::Validator::new(naga::proc::analyzer::AnalysisFlags::all()) - .validate(&m) - { - Ok(_analysis) => (), - //TODO: panic - Err(e) => log::error!("Unable to validate {}: {:?}", name, e), - }, + Ok(m) => { + match naga::valid::Validator::new(naga::valid::AnalysisFlags::all()).validate(&m) { + Ok(_info) => (), + //TODO: panic + Err(e) => log::error!("Unable to validate {}: {:?}", name, e), + } + } Err(e) => panic!("Unable to parse {}: {:?}", name, e), }; } diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 8cbcd8751c..1b06427aa9 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -63,7 +63,7 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) { Ok(string) => ron::de::from_str(&string).expect("Couldn't find param file"), Err(_) => Parameters::default(), }; - let analysis = naga::proc::Validator::new(naga::proc::analyzer::AnalysisFlags::all()) + let info = naga::valid::Validator::new(naga::valid::AnalysisFlags::all()) .validate(module) .unwrap(); @@ -78,7 +78,7 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) { } if targets.contains(Targets::ANALYSIS) { let config = ron::ser::PrettyConfig::default().with_new_line("\n".to_string()); - let output = ron::ser::to_string_pretty(&analysis, config).unwrap(); + let output = ron::ser::to_string_pretty(&info, config).unwrap(); with_snapshot_settings(|| { insta::assert_snapshot!(format!("{}.info.ron", name), output); }); @@ -88,27 +88,27 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) { #[cfg(feature = "spv-out")] { if targets.contains(Targets::SPIRV) { - check_output_spv(module, &analysis, name, ¶ms); + check_output_spv(module, &info, name, ¶ms); } } #[cfg(feature = "msl-out")] { if targets.contains(Targets::METAL) { - check_output_msl(module, &analysis, name, ¶ms); + check_output_msl(module, &info, name, ¶ms); } } #[cfg(feature = "glsl-out")] { if targets.contains(Targets::GLSL) { for ep in module.entry_points.iter() { - check_output_glsl(module, &analysis, name, ep.stage, &ep.name); + check_output_glsl(module, &info, name, ep.stage, &ep.name); } } } #[cfg(feature = "dot-out")] { if targets.contains(Targets::DOT) { - let string = naga::back::dot::write(module, Some(&analysis)).unwrap(); + let string = naga::back::dot::write(module, Some(&info)).unwrap(); with_snapshot_settings(|| { insta::assert_snapshot!(format!("{}.dot", name), string); }); @@ -119,7 +119,7 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) { #[cfg(feature = "spv-out")] fn check_output_spv( module: &naga::Module, - analysis: &naga::proc::analyzer::Analysis, + info: &naga::valid::ModuleInfo, name: &str, params: &Parameters, ) { @@ -132,7 +132,7 @@ fn check_output_spv( capabilities: params.spv_capabilities.clone(), }; - let spv = spv::write_vec(module, analysis, &options).unwrap(); + let spv = spv::write_vec(module, info, &options).unwrap(); let dis = rspirv::dr::load_words(spv) .expect("Produced invalid SPIR-V") @@ -145,7 +145,7 @@ fn check_output_spv( #[cfg(feature = "msl-out")] fn check_output_msl( module: &naga::Module, - analysis: &naga::proc::analyzer::Analysis, + info: &naga::valid::ModuleInfo, name: &str, params: &Parameters, ) { @@ -178,7 +178,7 @@ fn check_output_msl( fake_missing_bindings: false, }; - let (msl, _) = msl::write_string(module, analysis, &options).unwrap(); + let (msl, _) = msl::write_string(module, info, &options).unwrap(); with_snapshot_settings(|| { insta::assert_snapshot!(format!("{}.msl", name), msl); @@ -188,7 +188,7 @@ fn check_output_msl( #[cfg(feature = "glsl-out")] fn check_output_glsl( module: &naga::Module, - analysis: &naga::proc::analyzer::Analysis, + info: &naga::valid::ModuleInfo, name: &str, stage: naga::ShaderStage, ep_name: &str, @@ -202,7 +202,7 @@ fn check_output_glsl( }; let mut buffer = Vec::new(); - let mut writer = glsl::Writer::new(&mut buffer, module, analysis, &options).unwrap(); + let mut writer = glsl::Writer::new(&mut buffer, module, info, &options).unwrap(); writer.write().unwrap(); let string = String::from_utf8(buffer).unwrap(); @@ -278,7 +278,7 @@ fn convert_spv(name: &str, targets: Targets) { ) .unwrap(); check_targets(&module, name, targets); - naga::proc::Validator::new(naga::proc::analyzer::AnalysisFlags::all()) + naga::valid::Validator::new(naga::valid::AnalysisFlags::all()) .validate(&module) .unwrap(); }