From 63afb9a215d46d848fe258c2d712ca29487e3438 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 5 May 2021 17:52:50 -0400 Subject: [PATCH] Add Capabilities bitset for validation --- bin/naga.rs | 20 +++++++++++-------- src/back/msl/writer.rs | 4 ++-- src/valid/mod.rs | 16 +++++++++++++-- src/valid/type.rs | 16 +++++++++------ tests/parse.rs | 44 ------------------------------------------ tests/snapshots.rs | 18 +++++++++++------ 6 files changed, 50 insertions(+), 68 deletions(-) delete mode 100644 tests/parse.rs diff --git a/bin/naga.rs b/bin/naga.rs index d9b841ebc5..efa7ddd760 100644 --- a/bin/naga.rs +++ b/bin/naga.rs @@ -175,14 +175,18 @@ fn main() { }; // validate the IR - let info = - match naga::valid::Validator::new(naga::valid::ValidationFlags::all()).validate(&module) { - Ok(info) => Some(info), - Err(error) => { - print_err(error); - None - } - }; + let info = match naga::valid::Validator::new( + naga::valid::ValidationFlags::all(), + naga::valid::Capabilities::all(), + ) + .validate(&module) + { + Ok(info) => Some(info), + Err(error) => { + print_err(error); + None + } + }; let output_path = match output_path { Some(ref string) => string, diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index e08f0d61dc..3b1f9e3e02 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -2307,7 +2307,7 @@ impl Writer { #[test] fn test_stack_size() { - use crate::valid::ValidationFlags; + use crate::valid::{Capabilities, ValidationFlags}; // create a module with at least one expression nested let mut module = crate::Module::default(); let constant = module.constants.append(crate::Constant { @@ -2335,7 +2335,7 @@ fn test_stack_size() { }); let _ = module.functions.append(fun); // analyse the module - let info = crate::valid::Validator::new(ValidationFlags::empty()) + let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty()) .validate(&module) .unwrap(); // process the module diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 1191e77b83..8f3ac946ba 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -41,6 +41,16 @@ bitflags::bitflags! { } #[must_use] +bitflags::bitflags! { + /// Allowed IR capabilities. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + pub struct Capabilities: u8 { + /// Float values with width = 8 + const FLOAT64 = 0x1; + } +} + bitflags::bitflags! { /// Validation flags. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -70,6 +80,7 @@ impl ops::Index> for ModuleInfo { #[derive(Debug)] pub struct Validator { flags: ValidationFlags, + capabilities: Capabilities, types: Vec, location_mask: BitSet, bind_group_masks: Vec, @@ -171,9 +182,10 @@ impl crate::TypeInner { impl Validator { /// Construct a new validator instance. - pub fn new(flags: ValidationFlags) -> Self { + pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self { Validator { flags, + capabilities, types: Vec::new(), location_mask: BitSet::new(), bind_group_masks: Vec::new(), @@ -192,7 +204,7 @@ impl Validator { let con = &constants[handle]; match con.inner { crate::ConstantInner::Scalar { width, ref value } => { - if !Self::check_width(value.scalar_kind(), width) { + if !self.check_width(value.scalar_kind(), width) { return Err(ConstantError::InvalidType); } } diff --git a/src/valid/type.rs b/src/valid/type.rs index ff34e81fe8..056af0508d 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -1,3 +1,4 @@ +use super::Capabilities; use crate::arena::{Arena, Handle}; bitflags::bitflags! { @@ -143,10 +144,13 @@ impl TypeInfo { } impl super::Validator { - pub(super) fn check_width(kind: crate::ScalarKind, width: crate::Bytes) -> bool { + pub(super) fn check_width(&self, kind: crate::ScalarKind, width: crate::Bytes) -> bool { match kind { crate::ScalarKind::Bool => width == crate::BOOL_WIDTH, - _ => width == 4, + crate::ScalarKind::Float => { + width == 4 || (width == 8 && self.capabilities.contains(Capabilities::FLOAT64)) + } + crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4, } } @@ -164,7 +168,7 @@ impl super::Validator { use crate::TypeInner as Ti; Ok(match types[handle].inner { Ti::Scalar { kind, width } => { - if !Self::check_width(kind, width) { + if !self.check_width(kind, width) { return Err(TypeError::InvalidWidth(kind, width)); } TypeInfo::new( @@ -176,7 +180,7 @@ impl super::Validator { ) } Ti::Vector { size, kind, width } => { - if !Self::check_width(kind, width) { + if !self.check_width(kind, width) { return Err(TypeError::InvalidWidth(kind, width)); } let count = if size >= crate::VectorSize::Tri { 4 } else { 2 }; @@ -193,7 +197,7 @@ impl super::Validator { rows, width, } => { - if !Self::check_width(crate::ScalarKind::Float, width) { + if !self.check_width(crate::ScalarKind::Float, width) { return Err(TypeError::InvalidWidth(crate::ScalarKind::Float, width)); } let count = if rows >= crate::VectorSize::Tri { 4 } else { 2 }; @@ -217,7 +221,7 @@ impl super::Validator { width, class: _, } => { - if !Self::check_width(kind, width) { + if !self.check_width(kind, width) { return Err(TypeError::InvalidWidth(kind, width)); } TypeInfo::new(TypeFlags::SIZED, 0) diff --git a/tests/parse.rs b/tests/parse.rs deleted file mode 100644 index ff8eacbc91..0000000000 --- a/tests/parse.rs +++ /dev/null @@ -1,44 +0,0 @@ -//TODO: consider converting this to snapshots? - -#[cfg(feature = "glsl-in")] -fn _check_glsl(name: &str) { - let path = std::path::PathBuf::from("tests/cases").join(name); - let input = std::fs::read_to_string(path).unwrap(); - let stage = if name.ends_with(".vert") { - naga::ShaderStage::Vertex - } else if name.ends_with(".frag") { - naga::ShaderStage::Fragment - } else if name.ends_with(".comp") { - naga::ShaderStage::Compute - } else { - panic!("Unknown extension in {:?}", name) - }; - - let mut entry_points = naga::FastHashMap::default(); - entry_points.insert("main".to_string(), stage); - match naga::front::glsl::parse_str( - &input, - &naga::front::glsl::Options { - entry_points, - defines: Default::default(), - }, - ) { - Ok(m) => { - match naga::valid::Validator::new(naga::valid::ValidationFlags::all()).validate(&m) { - Ok(_info) => (), - //TODO: panic - Err(e) => log::error!("Unable to validate {}: {:?}", name, e), - } - } - Err(e) => panic!("Unable to parse {}: {:?}", name, e), - }; -} - -#[cfg(feature = "glsl-in")] -#[test] -fn parse_glsl() { - //check_glsl("glsl_constant_expression.vert"); //TODO - //check_glsl("glsl_if_preprocessor.vert"); - //check_glsl("glsl_preprocessor_abuse.vert"); - //check_glsl("glsl_vertex_test_shader.vert"); //TODO -} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 034f91093d..90a835db71 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -46,9 +46,12 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) { Ok(string) => ron::de::from_str(&string).expect("Couldn't find param file"), Err(_) => Parameters::default(), }; - let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all()) - .validate(module) - .unwrap(); + let info = naga::valid::Validator::new( + naga::valid::ValidationFlags::all(), + naga::valid::Capabilities::empty(), + ) + .validate(module) + .unwrap(); let dest = PathBuf::from(root).join(DIR_OUT).join(name); @@ -282,9 +285,12 @@ fn convert_spv(name: &str, adjust_coordinate_space: bool, targets: Targets) { ) .unwrap(); check_targets(&module, name, targets); - naga::valid::Validator::new(naga::valid::ValidationFlags::all()) - .validate(&module) - .unwrap(); + naga::valid::Validator::new( + naga::valid::ValidationFlags::all(), + naga::valid::Capabilities::empty(), + ) + .validate(&module) + .unwrap(); } #[cfg(feature = "spv-in")]