From 5f57c9eae261c2802cfdeee5a35665cb3456bdde Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sat, 6 Jun 2020 23:20:48 -0400 Subject: [PATCH] Move the shader validation logic into a module --- wgpu-core/src/device/mod.rs | 22 +- wgpu-core/src/lib.rs | 1 + wgpu-core/src/pipeline.rs | 490 +---------------------------- wgpu-core/src/validation.rs | 593 ++++++++++++++++++++++++++++++++++++ 4 files changed, 609 insertions(+), 497 deletions(-) create mode 100644 wgpu-core/src/validation.rs diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 6be9469912..959d1e2bb7 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -7,7 +7,7 @@ use crate::{ hub::{GfxBackend, Global, GlobalIdentityHandlerFactory, Input, Token}, id, pipeline, resource, swap_chain, track::{BufferState, TextureState, TrackerSet}, - FastHashMap, LifeGuard, PrivateFeatures, Stored, MAX_BIND_GROUPS, + validation, FastHashMap, LifeGuard, PrivateFeatures, Stored, MAX_BIND_GROUPS, }; use arrayvec::ArrayVec; @@ -1773,7 +1773,7 @@ impl Global { &rasterization_state.clone().unwrap_or_default(), ); - let mut interface = pipeline::StageInterface::default(); + let mut interface = validation::StageInterface::default(); let mut validated_stages = wgt::ShaderStage::empty(); let desc_vbs = unsafe { @@ -1820,7 +1820,7 @@ impl Global { }); interface.insert( attribute.shader_location, - pipeline::construct_vertex_format(attribute.format), + validation::MaybeOwned::Owned(validation::map_vertex_format(attribute.format)), ); } } @@ -1923,7 +1923,7 @@ impl Global { let shader_module = &shader_module_guard[desc.vertex_stage.module]; if let Some(ref module) = shader_module.module { - interface = pipeline::validate_stage( + interface = validation::check_stage( module, &group_layouts, entry_point_name, @@ -1952,7 +1952,7 @@ impl Global { if validated_stages == wgt::ShaderStage::VERTEX { if let Some(ref module) = shader_module.module { - interface = pipeline::validate_stage( + interface = validation::check_stage( module, &group_layouts, entry_point_name, @@ -1976,10 +1976,12 @@ impl Global { if validated_stages.contains(wgt::ShaderStage::FRAGMENT) { for (i, state) in color_states.iter().enumerate() { let output = &interface[&(i as wgt::ShaderLocation)]; - if !pipeline::check_texture_format(state.format, output) { - panic!( + if !validation::check_texture_format(state.format, output) { + log::error!( "Incompatible fragment output[{}]. Shader: {:?}. Expected: {:?}", - i, state.format, &**output + i, + state.format, + &**output ); } } @@ -2158,7 +2160,7 @@ impl Global { .map(|id| &bgl_guard[id.value].entries) .collect::>(); - let interface = pipeline::StageInterface::default(); + let interface = validation::StageInterface::default(); let pipeline_stage = &desc.compute_stage; let (shader_module_guard, _) = hub.shader_modules.read(&mut token); @@ -2170,7 +2172,7 @@ impl Global { let shader_module = &shader_module_guard[pipeline_stage.module]; if let Some(ref module) = shader_module.module { - let _ = pipeline::validate_stage( + let _ = validation::check_stage( module, &group_layouts, entry_point_name, diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index 99271b7941..b806d50126 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -36,6 +36,7 @@ pub mod power; pub mod resource; pub mod swap_chain; mod track; +mod validation; pub use hal::pso::read_spirv; diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index aa5c7aab80..3868fe502a 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -3,12 +3,11 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use crate::{ - binding_model::{BindEntryMap, BindGroupLayoutEntry, BindingType}, device::RenderPassContext, id::{DeviceId, PipelineLayoutId, ShaderModuleId}, - FastHashMap, LifeGuard, RawString, RefCount, Stored, U32Array, + validation::StageError, + LifeGuard, RawString, RefCount, Stored, U32Array, }; -use spirv_headers as spirv; use std::borrow::Borrow; use wgt::{ BufferAddress, ColorStateDescriptor, DepthStencilStateDescriptor, IndexFormat, InputStepMode, @@ -52,489 +51,6 @@ pub struct ProgrammableStageDescriptor { pub entry_point: RawString, } -#[derive(Clone, Debug)] -pub enum BindingError { - /// The binding is missing from the pipeline layout. - Missing, - /// The visibility flags don't include the shader stage. - Invisible, - /// The load/store access flags don't match the shader. - WrongUsage(naga::GlobalUse), - /// The type on the shader side does not match the pipeline binding. - WrongType, - /// The view dimension doesn't match the shader. - WrongTextureViewDimension { dim: spirv::Dim, is_array: bool }, - /// The component type of a sampled texture doesn't match the shader. - WrongTextureComponentType(Option), - /// Texture sampling capability doesn't match with the shader. - WrongTextureSampled, - /// The multisampled flag doesn't match. - WrongTextureMultisampled, -} - -#[derive(Clone, Debug)] -pub enum InputError { - /// The input is not provided by the earlier stage in the pipeline. - Missing, - /// The input type is not compatible with the provided. - WrongType, -} - -/// Errors produced when validating a programmable stage of a pipeline. -#[derive(Clone, Debug)] -pub enum ProgrammableStageError { - /// Unable to find an entry point matching the specified execution model. - MissingEntryPoint(spirv::ExecutionModel), - /// Error matching a global binding against the pipeline layout. - Binding { - set: u32, - binding: u32, - error: BindingError, - }, - /// Error matching the stage input against the previous stage outputs. - Input { - location: wgt::ShaderLocation, - error: InputError, - }, -} - -fn validate_binding( - module: &naga::Module, - var: &naga::GlobalVariable, - entry: &BindGroupLayoutEntry, - usage: naga::GlobalUse, -) -> Result<(), BindingError> { - let mut ty_inner = &module.types[var.ty].inner; - //TODO: change naga's IR to avoid a pointer here - if let naga::TypeInner::Pointer { base, class: _ } = *ty_inner { - ty_inner = &module.types[base].inner; - } - let allowed_usage = match *ty_inner { - naga::TypeInner::Struct { .. } => match entry.ty { - BindingType::UniformBuffer => naga::GlobalUse::LOAD, - BindingType::StorageBuffer => naga::GlobalUse::all(), - BindingType::ReadonlyStorageBuffer => naga::GlobalUse::LOAD, - _ => return Err(BindingError::WrongType), - }, - naga::TypeInner::Sampler => match entry.ty { - BindingType::Sampler | BindingType::ComparisonSampler => naga::GlobalUse::empty(), - _ => return Err(BindingError::WrongType), - }, - naga::TypeInner::Image { base, dim, flags } => { - if entry.multisampled != flags.contains(naga::ImageFlags::MULTISAMPLED) { - return Err(BindingError::WrongTextureMultisampled); - } - if flags.contains(naga::ImageFlags::ARRAYED) { - match (dim, entry.view_dimension) { - (spirv::Dim::Dim2D, wgt::TextureViewDimension::D2Array) => (), - (spirv::Dim::DimCube, wgt::TextureViewDimension::CubeArray) => (), - _ => { - return Err(BindingError::WrongTextureViewDimension { - dim, - is_array: true, - }) - } - } - } else { - match (dim, entry.view_dimension) { - (spirv::Dim::Dim1D, wgt::TextureViewDimension::D1) => (), - (spirv::Dim::Dim2D, wgt::TextureViewDimension::D2) => (), - (spirv::Dim::Dim3D, wgt::TextureViewDimension::D3) => (), - (spirv::Dim::DimCube, wgt::TextureViewDimension::Cube) => (), - _ => { - return Err(BindingError::WrongTextureViewDimension { - dim, - is_array: false, - }) - } - } - } - let (allowed_usage, is_sampled) = match entry.ty { - BindingType::SampledTexture => { - let expected_scalar_kind = match entry.texture_component_type { - wgt::TextureComponentType::Float => naga::ScalarKind::Float, - wgt::TextureComponentType::Sint => naga::ScalarKind::Sint, - wgt::TextureComponentType::Uint => naga::ScalarKind::Uint, - }; - match module.types[base].inner { - naga::TypeInner::Scalar { kind, .. } - | naga::TypeInner::Vector { kind, .. } - if kind == expected_scalar_kind => {} - naga::TypeInner::Scalar { kind, .. } - | naga::TypeInner::Vector { kind, .. } => { - return Err(BindingError::WrongTextureComponentType(Some(kind))) - } - _ => return Err(BindingError::WrongTextureComponentType(None)), - }; - (naga::GlobalUse::LOAD, true) - } - BindingType::ReadonlyStorageTexture => { - //TODO: check entry.storage_texture_format - (naga::GlobalUse::LOAD, false) - } - BindingType::WriteonlyStorageTexture => (naga::GlobalUse::STORE, false), - _ => return Err(BindingError::WrongType), - }; - if is_sampled != flags.contains(naga::ImageFlags::SAMPLED) { - return Err(BindingError::WrongTextureSampled); - } - allowed_usage - } - _ => return Err(BindingError::WrongType), - }; - if allowed_usage.contains(usage) { - Ok(()) - } else { - Err(BindingError::WrongUsage(usage)) - } -} - -fn is_sub_type(sub: &naga::TypeInner, provided: &naga::TypeInner) -> bool { - use naga::TypeInner as Ti; - - match (sub, provided) { - ( - &Ti::Scalar { - kind: k0, - width: w0, - }, - &Ti::Scalar { - kind: k1, - width: w1, - }, - ) => k0 == k1 && w0 <= w1, - ( - &Ti::Scalar { - kind: k0, - width: w0, - }, - &Ti::Vector { - size: _, - kind: k1, - width: w1, - }, - ) => k0 == k1 && w0 <= w1, - ( - &Ti::Vector { - size: s0, - kind: k0, - width: w0, - }, - &Ti::Vector { - size: s1, - kind: k1, - width: w1, - }, - ) => s0 as u8 <= s1 as u8 && k0 == k1 && w0 <= w1, - ( - &Ti::Matrix { - columns: c0, - rows: r0, - kind: k0, - width: w0, - }, - &Ti::Matrix { - columns: c1, - rows: r1, - kind: k1, - width: w1, - }, - ) => c0 == c1 && r0 == r1 && k0 == k1 && w0 <= w1, - (&Ti::Struct { members: ref m0 }, &Ti::Struct { members: ref m1 }) => m0 == m1, - _ => false, - } -} - -pub(crate) enum MaybeOwned<'a, T> { - Owned(T), - Borrowed(&'a T), -} - -impl<'a, T> std::ops::Deref for MaybeOwned<'a, T> { - type Target = T; - fn deref(&self) -> &T { - match *self { - MaybeOwned::Owned(ref value) => value, - MaybeOwned::Borrowed(value) => value, - } - } -} - -pub(crate) fn construct_vertex_format<'a>( - format: wgt::VertexFormat, -) -> MaybeOwned<'a, naga::TypeInner> { - use naga::TypeInner as Ti; - use wgt::VertexFormat as Vf; - MaybeOwned::Owned(match format { - Vf::Uchar2 => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Uint, - width: 8, - }, - Vf::Uchar4 => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Uint, - width: 8, - }, - Vf::Char2 => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Sint, - width: 8, - }, - Vf::Char4 => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Sint, - width: 8, - }, - Vf::Uchar2Norm => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Float, - width: 8, - }, - Vf::Uchar4Norm => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Float, - width: 8, - }, - Vf::Char2Norm => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Float, - width: 8, - }, - Vf::Char4Norm => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Float, - width: 8, - }, - Vf::Ushort2 => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Uint, - width: 16, - }, - Vf::Ushort4 => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Uint, - width: 16, - }, - Vf::Short2 => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Sint, - width: 16, - }, - Vf::Short4 => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Sint, - width: 16, - }, - Vf::Ushort2Norm | Vf::Short2Norm | Vf::Half2 => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Float, - width: 16, - }, - Vf::Ushort4Norm | Vf::Short4Norm | Vf::Half4 => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Float, - width: 16, - }, - Vf::Float => Ti::Scalar { - kind: naga::ScalarKind::Float, - width: 32, - }, - Vf::Float2 => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Float, - width: 32, - }, - Vf::Float3 => Ti::Vector { - size: naga::VectorSize::Tri, - kind: naga::ScalarKind::Float, - width: 32, - }, - Vf::Float4 => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Float, - width: 32, - }, - Vf::Uint => Ti::Scalar { - kind: naga::ScalarKind::Uint, - width: 32, - }, - Vf::Uint2 => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Uint, - width: 32, - }, - Vf::Uint3 => Ti::Vector { - size: naga::VectorSize::Tri, - kind: naga::ScalarKind::Uint, - width: 32, - }, - Vf::Uint4 => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Uint, - width: 32, - }, - Vf::Int => Ti::Scalar { - kind: naga::ScalarKind::Sint, - width: 32, - }, - Vf::Int2 => Ti::Vector { - size: naga::VectorSize::Bi, - kind: naga::ScalarKind::Sint, - width: 32, - }, - Vf::Int3 => Ti::Vector { - size: naga::VectorSize::Tri, - kind: naga::ScalarKind::Sint, - width: 32, - }, - Vf::Int4 => Ti::Vector { - size: naga::VectorSize::Quad, - kind: naga::ScalarKind::Sint, - width: 32, - }, - }) -} - -/// Return true if the fragment `format` is covered by the provided `output`. -pub(crate) fn check_texture_format( - format: wgt::TextureFormat, - output: &MaybeOwned, -) -> bool { - use naga::ScalarKind as Sk; - use wgt::TextureFormat as Tf; - - let (components, kind, width) = match *&**output { - naga::TypeInner::Scalar { kind, width } => (1, kind, width), - naga::TypeInner::Vector { size, kind, width } => (size as u8, kind, width), - _ => return false, - }; - let (req_components, req_kind, req_width) = match format { - Tf::R8Unorm | Tf::R8Snorm => (1, Sk::Float, 8), - Tf::R8Uint => (1, Sk::Uint, 8), - Tf::R8Sint => (1, Sk::Sint, 8), - Tf::R16Uint => (1, Sk::Uint, 16), - Tf::R16Sint => (1, Sk::Sint, 16), - Tf::R16Float => (1, Sk::Float, 16), - Tf::Rg8Unorm | Tf::Rg8Snorm => (2, Sk::Float, 8), - Tf::Rg8Uint => (2, Sk::Uint, 8), - Tf::Rg8Sint => (2, Sk::Sint, 8), - Tf::R32Uint => (1, Sk::Uint, 32), - Tf::R32Sint => (1, Sk::Sint, 32), - Tf::R32Float => (1, Sk::Float, 32), - Tf::Rg16Uint => (2, Sk::Uint, 16), - Tf::Rg16Sint => (2, Sk::Sint, 16), - Tf::Rg16Float => (2, Sk::Float, 16), - Tf::Rgba8Unorm - | Tf::Rgba8UnormSrgb - | Tf::Rgba8Snorm - | Tf::Bgra8Unorm - | Tf::Bgra8UnormSrgb => (4, Sk::Float, 8), - Tf::Rgba8Uint => (4, Sk::Uint, 8), - Tf::Rgba8Sint => (4, Sk::Sint, 8), - Tf::Rgb10a2Unorm => (4, Sk::Float, 10), - Tf::Rg11b10Float => (3, Sk::Float, 11), - Tf::Rg32Uint => (2, Sk::Uint, 32), - Tf::Rg32Sint => (2, Sk::Sint, 32), - Tf::Rg32Float => (2, Sk::Float, 32), - Tf::Rgba16Uint => (4, Sk::Uint, 16), - Tf::Rgba16Sint => (4, Sk::Sint, 16), - Tf::Rgba16Float => (4, Sk::Float, 16), - Tf::Rgba32Uint => (4, Sk::Uint, 32), - Tf::Rgba32Sint => (4, Sk::Sint, 32), - Tf::Rgba32Float => (4, Sk::Float, 32), - Tf::Depth32Float | Tf::Depth24Plus | Tf::Depth24PlusStencil8 => return false, - }; - - components >= req_components && kind == req_kind && width >= req_width -} - -pub(crate) type StageInterface<'a> = - FastHashMap>; - -pub(crate) fn validate_stage<'a>( - module: &'a naga::Module, - group_layouts: &[&BindEntryMap], - entry_point_name: &str, - execution_model: spirv::ExecutionModel, - inputs: StageInterface<'a>, -) -> Result, ProgrammableStageError> { - // Since a shader module can have multiple entry points with the same name, - // we need to look for one with the right execution model. - let entry_point = module - .entry_points - .iter() - .find(|entry_point| { - entry_point.name == entry_point_name && entry_point.exec_model == execution_model - }) - .ok_or(ProgrammableStageError::MissingEntryPoint(execution_model))?; - let stage_bit = match execution_model { - spirv::ExecutionModel::Vertex => wgt::ShaderStage::VERTEX, - spirv::ExecutionModel::Fragment => wgt::ShaderStage::FRAGMENT, - spirv::ExecutionModel::GLCompute => wgt::ShaderStage::COMPUTE, - // the entry point wouldn't match otherwise - _ => unreachable!(), - }; - - let function = &module.functions[entry_point.function]; - let mut outputs = StageInterface::default(); - for ((_, var), &usage) in module.global_variables.iter().zip(&function.global_usage) { - if usage.is_empty() { - continue; - } - match var.binding { - Some(naga::Binding::Descriptor { set, binding }) => { - let result = group_layouts - .get(set as usize) - .and_then(|map| map.get(&binding)) - .ok_or(BindingError::Missing) - .and_then(|entry| { - if entry.visibility.contains(stage_bit) { - Ok(entry) - } else { - Err(BindingError::Invisible) - } - }) - .and_then(|entry| validate_binding(module, var, entry, usage)); - if let Err(error) = result { - return Err(ProgrammableStageError::Binding { - set, - binding, - error, - }); - } - } - Some(naga::Binding::Location(location)) => { - let mut ty = &module.types[var.ty].inner; - //TODO: change naga's IR to not have pointer for varyings - if let naga::TypeInner::Pointer { base, class: _ } = *ty { - ty = &module.types[base].inner; - } - if usage.contains(naga::GlobalUse::STORE) { - outputs.insert(location, MaybeOwned::Borrowed(ty)); - } else { - let result = - inputs - .get(&location) - .ok_or(InputError::Missing) - .and_then(|provided| { - if is_sub_type(ty, provided) { - Ok(()) - } else { - Err(InputError::WrongType) - } - }); - if let Err(error) = result { - return Err(ProgrammableStageError::Input { location, error }); - } - } - } - _ => {} - } - } - Ok(outputs) -} - #[repr(C)] #[derive(Debug)] pub struct ComputePipelineDescriptor { @@ -544,7 +60,7 @@ pub struct ComputePipelineDescriptor { #[derive(Clone, Debug)] pub enum ComputePipelineError { - Stage(ProgrammableStageError), + Stage(StageError), } #[derive(Debug)] diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs new file mode 100644 index 0000000000..f46a2ef265 --- /dev/null +++ b/wgpu-core/src/validation.rs @@ -0,0 +1,593 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use crate::{ + binding_model::{BindEntryMap, BindGroupLayoutEntry, BindingType}, + FastHashMap, +}; +use spirv_headers as spirv; + +#[derive(Clone, Debug)] +pub enum BindingError { + /// The binding is missing from the pipeline layout. + Missing, + /// The visibility flags don't include the shader stage. + Invisible, + /// The load/store access flags don't match the shader. + WrongUsage(naga::GlobalUse), + /// The type on the shader side does not match the pipeline binding. + WrongType, + /// The view dimension doesn't match the shader. + WrongTextureViewDimension { dim: spirv::Dim, is_array: bool }, + /// The component type of a sampled texture doesn't match the shader. + WrongTextureComponentType(Option), + /// Texture sampling capability doesn't match with the shader. + WrongTextureSampled, + /// The multisampled flag doesn't match. + WrongTextureMultisampled, +} + +#[derive(Clone, Debug)] +pub enum InputError { + /// The input is not provided by the earlier stage in the pipeline. + Missing, + /// The input type is not compatible with the provided. + WrongType, +} + +/// Errors produced when validating a programmable stage of a pipeline. +#[derive(Clone, Debug)] +pub enum StageError { + /// Unable to find an entry point matching the specified execution model. + MissingEntryPoint(spirv::ExecutionModel), + /// Error matching a global binding against the pipeline layout. + Binding { + set: u32, + binding: u32, + error: BindingError, + }, + /// Error matching the stage input against the previous stage outputs. + Input { + location: wgt::ShaderLocation, + error: InputError, + }, +} + +fn check_binding( + module: &naga::Module, + var: &naga::GlobalVariable, + entry: &BindGroupLayoutEntry, + usage: naga::GlobalUse, +) -> Result<(), BindingError> { + let mut ty_inner = &module.types[var.ty].inner; + //TODO: change naga's IR to avoid a pointer here + if let naga::TypeInner::Pointer { base, class: _ } = *ty_inner { + ty_inner = &module.types[base].inner; + } + let allowed_usage = match *ty_inner { + naga::TypeInner::Struct { .. } => match entry.ty { + BindingType::UniformBuffer => naga::GlobalUse::LOAD, + BindingType::StorageBuffer => naga::GlobalUse::all(), + BindingType::ReadonlyStorageBuffer => naga::GlobalUse::LOAD, + _ => return Err(BindingError::WrongType), + }, + naga::TypeInner::Sampler => match entry.ty { + BindingType::Sampler | BindingType::ComparisonSampler => naga::GlobalUse::empty(), + _ => return Err(BindingError::WrongType), + }, + naga::TypeInner::Image { base, dim, flags } => { + if entry.multisampled != flags.contains(naga::ImageFlags::MULTISAMPLED) { + return Err(BindingError::WrongTextureMultisampled); + } + if flags.contains(naga::ImageFlags::ARRAYED) { + match (dim, entry.view_dimension) { + (spirv::Dim::Dim2D, wgt::TextureViewDimension::D2Array) => (), + (spirv::Dim::DimCube, wgt::TextureViewDimension::CubeArray) => (), + _ => { + return Err(BindingError::WrongTextureViewDimension { + dim, + is_array: true, + }) + } + } + } else { + match (dim, entry.view_dimension) { + (spirv::Dim::Dim1D, wgt::TextureViewDimension::D1) => (), + (spirv::Dim::Dim2D, wgt::TextureViewDimension::D2) => (), + (spirv::Dim::Dim3D, wgt::TextureViewDimension::D3) => (), + (spirv::Dim::DimCube, wgt::TextureViewDimension::Cube) => (), + _ => { + return Err(BindingError::WrongTextureViewDimension { + dim, + is_array: false, + }) + } + } + } + let (allowed_usage, is_sampled) = match entry.ty { + BindingType::SampledTexture => { + let expected_scalar_kind = match entry.texture_component_type { + wgt::TextureComponentType::Float => naga::ScalarKind::Float, + wgt::TextureComponentType::Sint => naga::ScalarKind::Sint, + wgt::TextureComponentType::Uint => naga::ScalarKind::Uint, + }; + match module.types[base].inner { + naga::TypeInner::Scalar { kind, .. } + | naga::TypeInner::Vector { kind, .. } + if kind == expected_scalar_kind => {} + naga::TypeInner::Scalar { kind, .. } + | naga::TypeInner::Vector { kind, .. } => { + return Err(BindingError::WrongTextureComponentType(Some(kind))) + } + _ => return Err(BindingError::WrongTextureComponentType(None)), + }; + (naga::GlobalUse::LOAD, true) + } + BindingType::ReadonlyStorageTexture => { + //TODO: check entry.storage_texture_format + (naga::GlobalUse::LOAD, false) + } + BindingType::WriteonlyStorageTexture => (naga::GlobalUse::STORE, false), + _ => return Err(BindingError::WrongType), + }; + if is_sampled != flags.contains(naga::ImageFlags::SAMPLED) { + return Err(BindingError::WrongTextureSampled); + } + allowed_usage + } + _ => return Err(BindingError::WrongType), + }; + if allowed_usage.contains(usage) { + Ok(()) + } else { + Err(BindingError::WrongUsage(usage)) + } +} + +fn is_sub_type(sub: &naga::TypeInner, provided: &naga::TypeInner) -> bool { + use naga::TypeInner as Ti; + + match (sub, provided) { + ( + &Ti::Scalar { + kind: k0, + width: w0, + }, + &Ti::Scalar { + kind: k1, + width: w1, + }, + ) => k0 == k1 && w0 <= w1, + ( + &Ti::Scalar { + kind: k0, + width: w0, + }, + &Ti::Vector { + size: _, + kind: k1, + width: w1, + }, + ) => k0 == k1 && w0 <= w1, + ( + &Ti::Vector { + size: s0, + kind: k0, + width: w0, + }, + &Ti::Vector { + size: s1, + kind: k1, + width: w1, + }, + ) => s0 as u8 <= s1 as u8 && k0 == k1 && w0 <= w1, + ( + &Ti::Matrix { + columns: c0, + rows: r0, + kind: k0, + width: w0, + }, + &Ti::Matrix { + columns: c1, + rows: r1, + kind: k1, + width: w1, + }, + ) => c0 == c1 && r0 == r1 && k0 == k1 && w0 <= w1, + (&Ti::Struct { members: ref m0 }, &Ti::Struct { members: ref m1 }) => m0 == m1, + _ => false, + } +} + +pub enum MaybeOwned<'a, T> { + Owned(T), + Borrowed(&'a T), +} + +impl<'a, T> std::ops::Deref for MaybeOwned<'a, T> { + type Target = T; + fn deref(&self) -> &T { + match *self { + MaybeOwned::Owned(ref value) => value, + MaybeOwned::Borrowed(value) => value, + } + } +} + +pub fn map_vertex_format(format: wgt::VertexFormat) -> naga::TypeInner { + use naga::TypeInner as Ti; + use wgt::VertexFormat as Vf; + match format { + Vf::Uchar2 => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Uint, + width: 8, + }, + Vf::Uchar4 => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Uint, + width: 8, + }, + Vf::Char2 => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Sint, + width: 8, + }, + Vf::Char4 => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Sint, + width: 8, + }, + Vf::Uchar2Norm => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Float, + width: 8, + }, + Vf::Uchar4Norm => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Float, + width: 8, + }, + Vf::Char2Norm => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Float, + width: 8, + }, + Vf::Char4Norm => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Float, + width: 8, + }, + Vf::Ushort2 => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Uint, + width: 16, + }, + Vf::Ushort4 => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Uint, + width: 16, + }, + Vf::Short2 => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Sint, + width: 16, + }, + Vf::Short4 => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Sint, + width: 16, + }, + Vf::Ushort2Norm | Vf::Short2Norm | Vf::Half2 => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Float, + width: 16, + }, + Vf::Ushort4Norm | Vf::Short4Norm | Vf::Half4 => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Float, + width: 16, + }, + Vf::Float => Ti::Scalar { + kind: naga::ScalarKind::Float, + width: 32, + }, + Vf::Float2 => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Float, + width: 32, + }, + Vf::Float3 => Ti::Vector { + size: naga::VectorSize::Tri, + kind: naga::ScalarKind::Float, + width: 32, + }, + Vf::Float4 => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Float, + width: 32, + }, + Vf::Uint => Ti::Scalar { + kind: naga::ScalarKind::Uint, + width: 32, + }, + Vf::Uint2 => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Uint, + width: 32, + }, + Vf::Uint3 => Ti::Vector { + size: naga::VectorSize::Tri, + kind: naga::ScalarKind::Uint, + width: 32, + }, + Vf::Uint4 => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Uint, + width: 32, + }, + Vf::Int => Ti::Scalar { + kind: naga::ScalarKind::Sint, + width: 32, + }, + Vf::Int2 => Ti::Vector { + size: naga::VectorSize::Bi, + kind: naga::ScalarKind::Sint, + width: 32, + }, + Vf::Int3 => Ti::Vector { + size: naga::VectorSize::Tri, + kind: naga::ScalarKind::Sint, + width: 32, + }, + Vf::Int4 => Ti::Vector { + size: naga::VectorSize::Quad, + kind: naga::ScalarKind::Sint, + width: 32, + }, + } +} + +fn map_texture_format(format: wgt::TextureFormat) -> naga::TypeInner { + use naga::{ScalarKind as Sk, TypeInner as Ti, VectorSize as Vs}; + use wgt::TextureFormat as Tf; + + match format { + Tf::R8Unorm | Tf::R8Snorm => Ti::Scalar { + kind: Sk::Float, + width: 8, + }, + Tf::R8Uint => Ti::Scalar { + kind: Sk::Uint, + width: 8, + }, + Tf::R8Sint => Ti::Scalar { + kind: Sk::Sint, + width: 8, + }, + Tf::R16Uint => Ti::Scalar { + kind: Sk::Uint, + width: 16, + }, + Tf::R16Sint => Ti::Scalar { + kind: Sk::Sint, + width: 16, + }, + Tf::R16Float => Ti::Scalar { + kind: Sk::Float, + width: 16, + }, + Tf::Rg8Unorm | Tf::Rg8Snorm => Ti::Vector { + size: Vs::Bi, + kind: Sk::Float, + width: 8, + }, + Tf::Rg8Uint => Ti::Vector { + size: Vs::Bi, + kind: Sk::Uint, + width: 8, + }, + Tf::Rg8Sint => Ti::Vector { + size: Vs::Bi, + kind: Sk::Sint, + width: 8, + }, + Tf::R32Uint => Ti::Scalar { + kind: Sk::Uint, + width: 32, + }, + Tf::R32Sint => Ti::Scalar { + kind: Sk::Sint, + width: 32, + }, + Tf::R32Float => Ti::Scalar { + kind: Sk::Float, + width: 32, + }, + Tf::Rg16Uint => Ti::Vector { + size: Vs::Bi, + kind: Sk::Uint, + width: 16, + }, + Tf::Rg16Sint => Ti::Vector { + size: Vs::Bi, + kind: Sk::Sint, + width: 16, + }, + Tf::Rg16Float => Ti::Vector { + size: Vs::Bi, + kind: Sk::Float, + width: 16, + }, + Tf::Rgba8Unorm + | Tf::Rgba8UnormSrgb + | Tf::Rgba8Snorm + | Tf::Bgra8Unorm + | Tf::Bgra8UnormSrgb => Ti::Vector { + size: Vs::Quad, + kind: Sk::Float, + width: 8, + }, + Tf::Rgba8Uint => Ti::Vector { + size: Vs::Quad, + kind: Sk::Uint, + width: 8, + }, + Tf::Rgba8Sint => Ti::Vector { + size: Vs::Quad, + kind: Sk::Sint, + width: 8, + }, + Tf::Rgb10a2Unorm => Ti::Vector { + size: Vs::Quad, + kind: Sk::Float, + width: 10, + }, + Tf::Rg11b10Float => Ti::Vector { + size: Vs::Tri, + kind: Sk::Float, + width: 11, + }, + Tf::Rg32Uint => Ti::Vector { + size: Vs::Bi, + kind: Sk::Uint, + width: 32, + }, + Tf::Rg32Sint => Ti::Vector { + size: Vs::Bi, + kind: Sk::Sint, + width: 32, + }, + Tf::Rg32Float => Ti::Vector { + size: Vs::Bi, + kind: Sk::Float, + width: 32, + }, + Tf::Rgba16Uint => Ti::Vector { + size: Vs::Quad, + kind: Sk::Uint, + width: 16, + }, + Tf::Rgba16Sint => Ti::Vector { + size: Vs::Quad, + kind: Sk::Sint, + width: 16, + }, + Tf::Rgba16Float => Ti::Vector { + size: Vs::Quad, + kind: Sk::Float, + width: 16, + }, + Tf::Rgba32Uint => Ti::Vector { + size: Vs::Quad, + kind: Sk::Uint, + width: 32, + }, + Tf::Rgba32Sint => Ti::Vector { + size: Vs::Quad, + kind: Sk::Sint, + width: 32, + }, + Tf::Rgba32Float => Ti::Vector { + size: Vs::Quad, + kind: Sk::Float, + width: 32, + }, + Tf::Depth32Float | Tf::Depth24Plus | Tf::Depth24PlusStencil8 => { + panic!("Unexpected depth format") + } + } +} + +/// Return true if the fragment `format` is covered by the provided `output`. +pub fn check_texture_format(format: wgt::TextureFormat, output: &naga::TypeInner) -> bool { + let required = map_texture_format(format); + is_sub_type(&required, output) +} + +pub type StageInterface<'a> = FastHashMap>; + +pub fn check_stage<'a>( + module: &'a naga::Module, + group_layouts: &[&BindEntryMap], + entry_point_name: &str, + execution_model: spirv::ExecutionModel, + inputs: StageInterface<'a>, +) -> Result, StageError> { + // Since a shader module can have multiple entry points with the same name, + // we need to look for one with the right execution model. + let entry_point = module + .entry_points + .iter() + .find(|entry_point| { + entry_point.name == entry_point_name && entry_point.exec_model == execution_model + }) + .ok_or(StageError::MissingEntryPoint(execution_model))?; + let stage_bit = match execution_model { + spirv::ExecutionModel::Vertex => wgt::ShaderStage::VERTEX, + spirv::ExecutionModel::Fragment => wgt::ShaderStage::FRAGMENT, + spirv::ExecutionModel::GLCompute => wgt::ShaderStage::COMPUTE, + // the entry point wouldn't match otherwise + _ => unreachable!(), + }; + + let function = &module.functions[entry_point.function]; + let mut outputs = StageInterface::default(); + for ((_, var), &usage) in module.global_variables.iter().zip(&function.global_usage) { + if usage.is_empty() { + continue; + } + match var.binding { + Some(naga::Binding::Descriptor { set, binding }) => { + let result = group_layouts + .get(set as usize) + .and_then(|map| map.get(&binding)) + .ok_or(BindingError::Missing) + .and_then(|entry| { + if entry.visibility.contains(stage_bit) { + Ok(entry) + } else { + Err(BindingError::Invisible) + } + }) + .and_then(|entry| check_binding(module, var, entry, usage)); + if let Err(error) = result { + return Err(StageError::Binding { + set, + binding, + error, + }); + } + } + Some(naga::Binding::Location(location)) => { + let mut ty = &module.types[var.ty].inner; + //TODO: change naga's IR to not have pointer for varyings + if let naga::TypeInner::Pointer { base, class: _ } = *ty { + ty = &module.types[base].inner; + } + if usage.contains(naga::GlobalUse::STORE) { + outputs.insert(location, MaybeOwned::Borrowed(ty)); + } else { + let result = + inputs + .get(&location) + .ok_or(InputError::Missing) + .and_then(|provided| { + if is_sub_type(ty, provided) { + Ok(()) + } else { + Err(InputError::WrongType) + } + }); + if let Err(error) = result { + return Err(StageError::Input { location, error }); + } + } + } + _ => {} + } + } + Ok(outputs) +}