From 15cdc794fae311cfd6b88cd348dae27d4fc40a4d Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 5 Jun 2020 10:01:19 -0400 Subject: [PATCH] Shader binding validation --- player/src/main.rs | 3 +- wgpu-core/src/binding_model.rs | 4 +- wgpu-core/src/conv.rs | 4 +- wgpu-core/src/device/mod.rs | 82 +++++++-------- wgpu-core/src/hub.rs | 3 +- wgpu-core/src/instance.rs | 31 ++++-- wgpu-core/src/lib.rs | 3 +- wgpu-core/src/pipeline.rs | 185 +++++++++++++++++++++++++++++++++ 8 files changed, 262 insertions(+), 53 deletions(-) diff --git a/player/src/main.rs b/player/src/main.rs index 0d28d6661c..3b7fe8528f 100644 --- a/player/src/main.rs +++ b/player/src/main.rs @@ -355,7 +355,8 @@ impl GlobalExt for wgc::hub::Global { compute_stage: cs_stage.desc, }, id, - ); + ) + .unwrap(); } A::DestroyComputePipeline(id) => { self.compute_pipeline_destroy::(id); diff --git a/wgpu-core/src/binding_model.rs b/wgpu-core/src/binding_model.rs index 9779e8bcee..e94359aa27 100644 --- a/wgpu-core/src/binding_model.rs +++ b/wgpu-core/src/binding_model.rs @@ -93,12 +93,14 @@ pub enum BindGroupLayoutError { Entry(u32, BindGroupLayoutEntryError), } +pub(crate) type BindEntryMap = FastHashMap; + #[derive(Debug)] pub struct BindGroupLayout { pub(crate) raw: B::DescriptorSetLayout, pub(crate) device_id: Stored, pub(crate) life_guard: LifeGuard, - pub(crate) entries: FastHashMap, + pub(crate) entries: BindEntryMap, pub(crate) desc_counts: DescriptorCounts, pub(crate) dynamic_count: usize, } diff --git a/wgpu-core/src/conv.rs b/wgpu-core/src/conv.rs index df6460d310..26e449111f 100644 --- a/wgpu-core/src/conv.rs +++ b/wgpu-core/src/conv.rs @@ -371,14 +371,14 @@ pub(crate) fn map_texture_format( // Depth and stencil formats Tf::Depth32Float => H::D32Sfloat, Tf::Depth24Plus => { - if private_features.supports_texture_d24_s8 { + if private_features.texture_d24_s8 { H::D24UnormS8Uint } else { H::D32Sfloat } } Tf::Depth24PlusStencil8 => { - if private_features.supports_texture_d24_s8 { + if private_features.texture_d24_s8 { H::D24UnormS8Uint } else { H::D32SfloatS8Uint diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 79e857f8f6..242b080850 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -7,7 +7,7 @@ use crate::{ hub::{GfxBackend, Global, GlobalIdentityHandlerFactory, Input, Token}, id, pipeline, resource, swap_chain, track::{BufferState, TextureState, TrackerSet}, - FastHashMap, LifeGuard, PrivateFeatures, Stored, + FastHashMap, LifeGuard, PrivateFeatures, Stored, MAX_BIND_GROUPS, }; use arrayvec::ArrayVec; @@ -196,7 +196,7 @@ impl Device { queue_group: hal::queue::QueueGroup, mem_props: hal::adapter::MemoryProperties, hal_limits: hal::Limits, - supports_texture_d24_s8: bool, + private_features: PrivateFeatures, desc: &wgt::DeviceDescriptor, trace_path: Option<&std::path::Path>, ) -> Self { @@ -253,9 +253,7 @@ impl Device { } }), hal_limits, - private_features: PrivateFeatures { - supports_texture_d24_s8, - }, + private_features, limits: desc.limits.clone(), extensions: desc.extensions.clone(), pending_writes: queue::PendingWrites::new(), @@ -1578,7 +1576,7 @@ impl Global { let spv = unsafe { slice::from_raw_parts(desc.code.bytes, desc.code.length) }; let raw = unsafe { device.raw.create_shader_module(spv).unwrap() }; - let module = { + let module = if device.private_features.shader_validation { // Parse the given shader code and store its representation. let spv_iter = spv.into_iter().cloned(); let mut parser = naga::front::spirv::Parser::new(spv_iter); @@ -1589,6 +1587,8 @@ impl Global { log::warn!("Shader module will not be validated"); }) .ok() + } else { + None }; let shader = pipeline::ShaderModule { raw, @@ -1859,7 +1859,14 @@ impl Global { let device = &device_guard[device_id]; let (raw_pipeline, layout_ref_count) = { let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token); + let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token); let layout = &pipeline_layout_guard[desc.layout]; + let group_layouts = layout + .bind_group_layout_ids + .iter() + .map(|id| &bgl_guard[id.value].entries) + .collect::>(); + let (shader_module_guard, _) = hub.shader_modules.read(&mut token); let rp_key = RenderPassKey { @@ -1909,9 +1916,12 @@ impl Global { let shader_module = &shader_module_guard[desc.vertex_stage.module]; if let Some(ref module) = shader_module.module { - if let Err(e) = - validate_shader(module, entry_point_name, ExecutionModel::Vertex) - { + if let Err(e) = pipeline::validate_stage( + module, + &group_layouts, + entry_point_name, + ExecutionModel::Vertex, + ) { log::error!("Failed validating vertex shader module: {:?}", e); } } @@ -1934,9 +1944,12 @@ impl Global { let shader_module = &shader_module_guard[stage.module]; if let Some(ref module) = shader_module.module { - if let Err(e) = - validate_shader(module, entry_point_name, ExecutionModel::Fragment) - { + if let Err(e) = pipeline::validate_stage( + module, + &group_layouts, + entry_point_name, + ExecutionModel::Fragment, + ) { log::error!("Failed validating fragment shader module: {:?}", e); } } @@ -2106,7 +2119,7 @@ impl Global { device_id: id::DeviceId, desc: &pipeline::ComputePipelineDescriptor, id_in: Input, - ) -> id::ComputePipelineId { + ) -> Result { let hub = B::hub(self); let mut token = Token::root(); @@ -2114,7 +2127,14 @@ impl Global { let device = &device_guard[device_id]; let (raw_pipeline, layout_ref_count) = { let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token); + let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token); let layout = &pipeline_layout_guard[desc.layout]; + let group_layouts = layout + .bind_group_layout_ids + .iter() + .map(|id| &bgl_guard[id.value].entries) + .collect::>(); + let pipeline_stage = &desc.compute_stage; let (shader_module_guard, _) = hub.shader_modules.read(&mut token); @@ -2126,9 +2146,13 @@ impl Global { let shader_module = &shader_module_guard[pipeline_stage.module]; if let Some(ref module) = shader_module.module { - if let Err(e) = validate_shader(module, entry_point_name, ExecutionModel::GLCompute) - { - log::error!("Failed validating compute shader module: {:?}", e); + if let Err(e) = pipeline::validate_stage( + module, + &group_layouts, + entry_point_name, + ExecutionModel::GLCompute, + ) { + return Err(pipeline::ComputePipelineError::Stage(e)); } } @@ -2186,7 +2210,7 @@ impl Global { }), None => (), }; - id + Ok(id) } pub fn compute_pipeline_destroy( @@ -2588,27 +2612,3 @@ impl Global { } } } - -/// Errors produced when validating the shader modules of a pipeline. -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -enum ShaderValidationError { - /// Unable to find an entry point matching the specified execution model. - MissingEntryPoint(ExecutionModel), -} - -fn validate_shader( - module: &naga::Module, - entry_point_name: &str, - execution_model: ExecutionModel, -) -> Result<(), ShaderValidationError> { - // Since a shader module can have multiple entry points with the same name, - // we need to look for one with the right execution model. - let entry_point = module.entry_points.iter().find(|entry_point| { - entry_point.name == entry_point_name && entry_point.exec_model == execution_model - }); - - match entry_point { - Some(_) => Ok(()), - None => Err(ShaderValidationError::MissingEntryPoint(execution_model)), - } -} diff --git a/wgpu-core/src/hub.rs b/wgpu-core/src/hub.rs index 1e71733e22..9f2c1714f4 100644 --- a/wgpu-core/src/hub.rs +++ b/wgpu-core/src/hub.rs @@ -177,6 +177,7 @@ impl Access> for Device {} impl Access> for CommandBuffer {} impl Access> for Root {} impl Access> for Device {} +impl Access> for PipelineLayout {} impl Access> for Root {} impl Access> for Device {} impl Access> for BindGroupLayout {} @@ -191,7 +192,7 @@ impl Access> for Device {} impl Access> for BindGroup {} impl Access> for ComputePipeline {} impl Access> for Device {} -impl Access> for PipelineLayout {} +impl Access> for BindGroupLayout {} impl Access> for Root {} impl Access> for Device {} impl Access> for BindGroupLayout {} diff --git a/wgpu-core/src/instance.rs b/wgpu-core/src/instance.rs index bb7c6c1b81..7bd3f875fb 100644 --- a/wgpu-core/src/instance.rs +++ b/wgpu-core/src/instance.rs @@ -7,7 +7,7 @@ use crate::{ device::Device, hub::{GfxBackend, Global, GlobalIdentityHandlerFactory, Input, Token}, id::{AdapterId, DeviceId, SurfaceId}, - power, LifeGuard, Stored, MAX_BIND_GROUPS, + power, LifeGuard, PrivateFeatures, Stored, MAX_BIND_GROUPS, }; use wgt::{Backend, BackendBit, DeviceDescriptor, PowerPreference, BIND_BUFFER_ALIGNMENT}; @@ -644,10 +644,29 @@ impl Global { } let mem_props = phd.memory_properties(); - let supports_texture_d24_s8 = phd - .format_properties(Some(hal::format::Format::D24UnormS8Uint)) - .optimal_tiling - .contains(hal::format::ImageFeature::DEPTH_STENCIL_ATTACHMENT); + let private_features = PrivateFeatures { + shader_validation: match std::env::var("WGPU_SHADER_VALIDATION") { + Ok(var) => match var.as_str() { + "0" => { + log::info!("Shader validation is disabled"); + false + } + "1" => { + log::info!("Shader validation is enabled"); + true + } + _ => { + log::warn!("Unknown shader validation setting: {:?}", var); + true + } + }, + _ => true, + }, + texture_d24_s8: phd + .format_properties(Some(hal::format::Format::D24UnormS8Uint)) + .optimal_tiling + .contains(hal::format::ImageFeature::DEPTH_STENCIL_ATTACHMENT), + }; Device::new( gpu.device, @@ -658,7 +677,7 @@ impl Global { gpu.queue_groups.swap_remove(0), mem_props, limits, - supports_texture_d24_s8, + private_features, desc, trace_path, ) diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index 883519c64d..99271b7941 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -171,7 +171,8 @@ pub struct U32Array { #[derive(Clone, Copy, Debug)] struct PrivateFeatures { - pub supports_texture_d24_s8: bool, + shader_validation: bool, + texture_d24_s8: bool, } #[macro_export] diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index bff27850ba..05101e84b2 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -3,10 +3,12 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use crate::{ + binding_model::{BindEntryMap, BindGroupLayoutEntry, BindingType}, device::RenderPassContext, id::{DeviceId, PipelineLayoutId, ShaderModuleId}, LifeGuard, RawString, RefCount, Stored, U32Array, }; +use spirv_headers as spirv; use std::borrow::Borrow; use wgt::{ BufferAddress, ColorStateDescriptor, DepthStencilStateDescriptor, IndexFormat, InputStepMode, @@ -50,6 +52,184 @@ pub struct ProgrammableStageDescriptor { pub entry_point: RawString, } +#[derive(Clone, Debug)] +pub enum BindingError { + /// The binding is missing from the pipeline layout. + Missing, + /// The visibility flags don't include the shader stage. + Invisible, + /// The load/store access flags don't match the shader. + WrongUsage(naga::GlobalUse), + /// The type on the shader side does not match the pipeline binding. + WrongType, + /// The view dimension doesn't match the shader. + WrongTextureViewDimension { dim: spirv::Dim, is_array: bool }, + /// The component type of a sampled texture doesn't match the shader. + WrongTextureComponentType(Option), + /// Texture sampling capability doesn't match with the shader. + WrongTextureSampled, + /// The multisampled flag doesn't match. + WrongTextureMultisampled, +} + +/// Errors produced when validating a programmable stage of a pipeline. +#[derive(Clone, Debug)] +pub enum ProgrammableStageError { + /// Unable to find an entry point matching the specified execution model. + MissingEntryPoint(spirv::ExecutionModel), + /// Error matching a global binding to the pipeline layout. + Binding { + set: u32, + binding: u32, + error: BindingError, + }, +} + +fn validate_binding( + module: &naga::Module, + var: &naga::GlobalVariable, + entry: &BindGroupLayoutEntry, + usage: naga::GlobalUse, +) -> Result<(), BindingError> { + let allowed_usage = match module.types[var.ty].inner { + naga::TypeInner::Struct { .. } => match entry.ty { + BindingType::UniformBuffer => naga::GlobalUse::LOAD, + BindingType::StorageBuffer => naga::GlobalUse::all(), + BindingType::ReadonlyStorageBuffer => naga::GlobalUse::LOAD, + _ => return Err(BindingError::WrongType), + }, + naga::TypeInner::Sampler => match entry.ty { + BindingType::Sampler | BindingType::ComparisonSampler => naga::GlobalUse::empty(), + _ => return Err(BindingError::WrongType), + }, + naga::TypeInner::Image { base, dim, flags } => { + if entry.multisampled != flags.contains(naga::ImageFlags::MULTISAMPLED) { + return Err(BindingError::WrongTextureMultisampled); + } + if flags.contains(naga::ImageFlags::ARRAYED) { + match (dim, entry.view_dimension) { + (spirv::Dim::Dim2D, wgt::TextureViewDimension::D2Array) => (), + (spirv::Dim::DimCube, wgt::TextureViewDimension::CubeArray) => (), + _ => { + return Err(BindingError::WrongTextureViewDimension { + dim, + is_array: true, + }) + } + } + } else { + match (dim, entry.view_dimension) { + (spirv::Dim::Dim1D, wgt::TextureViewDimension::D1) => (), + (spirv::Dim::Dim2D, wgt::TextureViewDimension::D2) => (), + (spirv::Dim::Dim3D, wgt::TextureViewDimension::D3) => (), + (spirv::Dim::DimCube, wgt::TextureViewDimension::Cube) => (), + _ => { + return Err(BindingError::WrongTextureViewDimension { + dim, + is_array: false, + }) + } + } + } + let (allowed_usage, is_sampled) = match entry.ty { + BindingType::SampledTexture => { + let expected_scalar_kind = match entry.texture_component_type { + wgt::TextureComponentType::Float => naga::ScalarKind::Float, + wgt::TextureComponentType::Sint => naga::ScalarKind::Sint, + wgt::TextureComponentType::Uint => naga::ScalarKind::Uint, + }; + match module.types[base].inner { + naga::TypeInner::Scalar { kind, .. } + | naga::TypeInner::Vector { kind, .. } + if kind == expected_scalar_kind => + { + () + } + naga::TypeInner::Scalar { kind, .. } + | naga::TypeInner::Vector { kind, .. } => { + return Err(BindingError::WrongTextureComponentType(Some(kind))) + } + _ => return Err(BindingError::WrongTextureComponentType(None)), + }; + (naga::GlobalUse::LOAD, true) + } + BindingType::ReadonlyStorageTexture => { + //TODO: check entry.storage_texture_format + (naga::GlobalUse::LOAD, false) + } + BindingType::WriteonlyStorageTexture => (naga::GlobalUse::STORE, false), + _ => return Err(BindingError::WrongType), + }; + if is_sampled != flags.contains(naga::ImageFlags::SAMPLED) { + return Err(BindingError::WrongTextureSampled); + } + allowed_usage + } + _ => return Err(BindingError::WrongType), + }; + if allowed_usage.contains(usage) { + Ok(()) + } else { + Err(BindingError::WrongUsage(usage)) + } +} + +pub(crate) fn validate_stage( + module: &naga::Module, + group_layouts: &[&BindEntryMap], + entry_point_name: &str, + execution_model: spirv::ExecutionModel, +) -> Result<(), ProgrammableStageError> { + // Since a shader module can have multiple entry points with the same name, + // we need to look for one with the right execution model. + let entry_point = module + .entry_points + .iter() + .find(|entry_point| { + entry_point.name == entry_point_name && entry_point.exec_model == execution_model + }) + .ok_or(ProgrammableStageError::MissingEntryPoint(execution_model))?; + let stage_bit = match execution_model { + spirv::ExecutionModel::Vertex => wgt::ShaderStage::VERTEX, + spirv::ExecutionModel::Fragment => wgt::ShaderStage::FRAGMENT, + spirv::ExecutionModel::GLCompute => wgt::ShaderStage::COMPUTE, + // the entry point wouldn't match otherwise + _ => unreachable!(), + }; + + let function = &module.functions[entry_point.function]; + for ((_, var), &usage) in module.global_variables.iter().zip(&function.global_usage) { + if usage.is_empty() { + continue; + } + match var.binding { + Some(naga::Binding::Descriptor { set, binding }) => { + let result = group_layouts + .get(set as usize) + .and_then(|map| map.get(&binding)) + .ok_or(BindingError::Missing) + .and_then(|entry| { + if entry.visibility.contains(stage_bit) { + Ok(entry) + } else { + Err(BindingError::Invisible) + } + }) + .and_then(|entry| validate_binding(module, var, entry, usage)); + if let Err(error) = result { + return Err(ProgrammableStageError::Binding { + set, + binding, + error, + }); + } + } + _ => {} //TODO + } + } + Ok(()) +} + #[repr(C)] #[derive(Debug)] pub struct ComputePipelineDescriptor { @@ -57,6 +237,11 @@ pub struct ComputePipelineDescriptor { pub compute_stage: ProgrammableStageDescriptor, } +#[derive(Clone, Debug)] +pub enum ComputePipelineError { + Stage(ProgrammableStageError), +} + #[derive(Debug)] pub struct ComputePipeline { pub(crate) raw: B::ComputePipeline,