Add Capabilities bitset for validation

This commit is contained in:
Dzmitry Malyshau
2021-05-05 17:52:50 -04:00
committed by Dzmitry Malyshau
parent 3a0f014411
commit 63afb9a215
6 changed files with 50 additions and 68 deletions

View File

@@ -175,14 +175,18 @@ fn main() {
};
// validate the IR
let info =
match naga::valid::Validator::new(naga::valid::ValidationFlags::all()).validate(&module) {
Ok(info) => Some(info),
Err(error) => {
print_err(error);
None
}
};
let info = match naga::valid::Validator::new(
naga::valid::ValidationFlags::all(),
naga::valid::Capabilities::all(),
)
.validate(&module)
{
Ok(info) => Some(info),
Err(error) => {
print_err(error);
None
}
};
let output_path = match output_path {
Some(ref string) => string,

View File

@@ -2307,7 +2307,7 @@ impl<W: Write> Writer<W> {
#[test]
fn test_stack_size() {
use crate::valid::ValidationFlags;
use crate::valid::{Capabilities, ValidationFlags};
// create a module with at least one expression nested
let mut module = crate::Module::default();
let constant = module.constants.append(crate::Constant {
@@ -2335,7 +2335,7 @@ fn test_stack_size() {
});
let _ = module.functions.append(fun);
// analyse the module
let info = crate::valid::Validator::new(ValidationFlags::empty())
let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty())
.validate(&module)
.unwrap();
// process the module

View File

@@ -41,6 +41,16 @@ bitflags::bitflags! {
}
#[must_use]
bitflags::bitflags! {
/// Allowed IR capabilities.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Capabilities: u8 {
/// Float values with width = 8
const FLOAT64 = 0x1;
}
}
bitflags::bitflags! {
/// Validation flags.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
@@ -70,6 +80,7 @@ impl ops::Index<Handle<crate::Function>> for ModuleInfo {
#[derive(Debug)]
pub struct Validator {
flags: ValidationFlags,
capabilities: Capabilities,
types: Vec<r#type::TypeInfo>,
location_mask: BitSet,
bind_group_masks: Vec<BitSet>,
@@ -171,9 +182,10 @@ impl crate::TypeInner {
impl Validator {
/// Construct a new validator instance.
pub fn new(flags: ValidationFlags) -> Self {
pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
Validator {
flags,
capabilities,
types: Vec::new(),
location_mask: BitSet::new(),
bind_group_masks: Vec::new(),
@@ -192,7 +204,7 @@ impl Validator {
let con = &constants[handle];
match con.inner {
crate::ConstantInner::Scalar { width, ref value } => {
if !Self::check_width(value.scalar_kind(), width) {
if !self.check_width(value.scalar_kind(), width) {
return Err(ConstantError::InvalidType);
}
}

View File

@@ -1,3 +1,4 @@
use super::Capabilities;
use crate::arena::{Arena, Handle};
bitflags::bitflags! {
@@ -143,10 +144,13 @@ impl TypeInfo {
}
impl super::Validator {
pub(super) fn check_width(kind: crate::ScalarKind, width: crate::Bytes) -> bool {
pub(super) fn check_width(&self, kind: crate::ScalarKind, width: crate::Bytes) -> bool {
match kind {
crate::ScalarKind::Bool => width == crate::BOOL_WIDTH,
_ => width == 4,
crate::ScalarKind::Float => {
width == 4 || (width == 8 && self.capabilities.contains(Capabilities::FLOAT64))
}
crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4,
}
}
@@ -164,7 +168,7 @@ impl super::Validator {
use crate::TypeInner as Ti;
Ok(match types[handle].inner {
Ti::Scalar { kind, width } => {
if !Self::check_width(kind, width) {
if !self.check_width(kind, width) {
return Err(TypeError::InvalidWidth(kind, width));
}
TypeInfo::new(
@@ -176,7 +180,7 @@ impl super::Validator {
)
}
Ti::Vector { size, kind, width } => {
if !Self::check_width(kind, width) {
if !self.check_width(kind, width) {
return Err(TypeError::InvalidWidth(kind, width));
}
let count = if size >= crate::VectorSize::Tri { 4 } else { 2 };
@@ -193,7 +197,7 @@ impl super::Validator {
rows,
width,
} => {
if !Self::check_width(crate::ScalarKind::Float, width) {
if !self.check_width(crate::ScalarKind::Float, width) {
return Err(TypeError::InvalidWidth(crate::ScalarKind::Float, width));
}
let count = if rows >= crate::VectorSize::Tri { 4 } else { 2 };
@@ -217,7 +221,7 @@ impl super::Validator {
width,
class: _,
} => {
if !Self::check_width(kind, width) {
if !self.check_width(kind, width) {
return Err(TypeError::InvalidWidth(kind, width));
}
TypeInfo::new(TypeFlags::SIZED, 0)

View File

@@ -1,44 +0,0 @@
//TODO: consider converting this to snapshots?
#[cfg(feature = "glsl-in")]
fn _check_glsl(name: &str) {
let path = std::path::PathBuf::from("tests/cases").join(name);
let input = std::fs::read_to_string(path).unwrap();
let stage = if name.ends_with(".vert") {
naga::ShaderStage::Vertex
} else if name.ends_with(".frag") {
naga::ShaderStage::Fragment
} else if name.ends_with(".comp") {
naga::ShaderStage::Compute
} else {
panic!("Unknown extension in {:?}", name)
};
let mut entry_points = naga::FastHashMap::default();
entry_points.insert("main".to_string(), stage);
match naga::front::glsl::parse_str(
&input,
&naga::front::glsl::Options {
entry_points,
defines: Default::default(),
},
) {
Ok(m) => {
match naga::valid::Validator::new(naga::valid::ValidationFlags::all()).validate(&m) {
Ok(_info) => (),
//TODO: panic
Err(e) => log::error!("Unable to validate {}: {:?}", name, e),
}
}
Err(e) => panic!("Unable to parse {}: {:?}", name, e),
};
}
#[cfg(feature = "glsl-in")]
#[test]
fn parse_glsl() {
//check_glsl("glsl_constant_expression.vert"); //TODO
//check_glsl("glsl_if_preprocessor.vert");
//check_glsl("glsl_preprocessor_abuse.vert");
//check_glsl("glsl_vertex_test_shader.vert"); //TODO
}

View File

@@ -46,9 +46,12 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) {
Ok(string) => ron::de::from_str(&string).expect("Couldn't find param file"),
Err(_) => Parameters::default(),
};
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all())
.validate(module)
.unwrap();
let info = naga::valid::Validator::new(
naga::valid::ValidationFlags::all(),
naga::valid::Capabilities::empty(),
)
.validate(module)
.unwrap();
let dest = PathBuf::from(root).join(DIR_OUT).join(name);
@@ -282,9 +285,12 @@ fn convert_spv(name: &str, adjust_coordinate_space: bool, targets: Targets) {
)
.unwrap();
check_targets(&module, name, targets);
naga::valid::Validator::new(naga::valid::ValidationFlags::all())
.validate(&module)
.unwrap();
naga::valid::Validator::new(
naga::valid::ValidationFlags::all(),
naga::valid::Capabilities::empty(),
)
.validate(&module)
.unwrap();
}
#[cfg(feature = "spv-in")]