//! Infrastructure for testing particular behavior of shaders across platforms. //! //! The tests take the form of a input buffer filled with u32 data. A compute //! shader is run on the input buffer which generates an output buffer. This //! buffer is then read and compared to a given output. use std::{borrow::Cow, fmt::Debug}; use wgpu::{ Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor, ComputePipelineDescriptor, MapMode, PipelineLayoutDescriptor, PollType, PushConstantRange, ShaderModuleDescriptor, ShaderSource, ShaderStages, }; use wgpu_test::TestingContext; pub mod array_size_overrides; pub mod compilation_messages; pub mod data_builtins; pub mod numeric_builtins; pub mod struct_layout; pub mod workgroup_size_overrides; pub mod zero_init_workgroup_mem; #[derive(Clone, Copy, PartialEq)] enum InputStorageType { Uniform, Storage, PushConstant, } impl InputStorageType { fn as_str(&self) -> &'static str { match self { InputStorageType::Uniform => "uniform", InputStorageType::Storage => "storage", InputStorageType::PushConstant => "push_constant", } } } /// Describes a single test of a shader. struct ShaderTest { /// Human readable name name: String, /// Header text. This is arbitrary code injected at the top of the shader. Replaces {{header}} header: String, /// This text will be the body of the `Input` struct. Replaces "{{input_members}}" /// in the shader_test shader. custom_struct_members: String, /// This text will be the body of the compute shader. Replaces "{{body}}" /// in the shader_test shader. body: String, /// This text will be the input type of the compute shader. Replaces "{{input_type}}". /// /// Defaults to "CustomStruct" input_type: String, /// This text will be the output type of the compute shader. Replaces "{{output_type}}". /// /// Defaults to "array". output_type: String, /// List of values will be written to the input buffer. input_values: Vec, /// List of lists of valid expected outputs from the shader. output_values: Vec>, /// Function which compares the output values to the resulting values and /// prints a message on failure. /// /// Defaults [`Self::default_comparison_function`]. output_comparison_fn: fn(&str, &[u32], &[Vec]) -> bool, /// Value to pre-initialize the output buffer to. Often u32::MAX so /// that writing a 0 looks different than not writing a value at all. /// /// Defaults to u32::MAX. output_initialization: u32, /// Which backends this test will fail on. If the test passes on this /// backend when it shouldn't, an assert will be raised. /// /// Defaults to Backends::empty(). failures: Backends, } impl ShaderTest { fn default_comparison_function( test_name: &str, actual_values: &[u32], expected_values: &[Vec], ) -> bool { let cast_actual = bytemuck::cast_slice::(actual_values); // When printing the error message, we want to trim `cast_actual` to the length // of the longest set of expected values. This tracks that value. let mut max_relevant_value_count = 0; for expected in expected_values { let cast_expected = bytemuck::cast_slice::(expected); // We shorten the actual to the length of the expected. if &cast_actual[0..cast_expected.len()] == cast_expected { return true; } max_relevant_value_count = max_relevant_value_count.max(cast_expected.len()); } // We haven't found a match, lets print an error. eprint!( "Inner test failure. Actual {:?}. Expected", &cast_actual[0..max_relevant_value_count] ); if expected_values.len() != 1 { eprint!(" one of: "); } else { eprint!(": "); } for (idx, expected) in expected_values.iter().enumerate() { let cast_expected = bytemuck::cast_slice::(expected); eprint!("{cast_expected:?}"); if idx + 1 != expected_values.len() { eprint!(" "); } } eprintln!(". Test {test_name}"); false } fn new( name: String, custom_struct_members: String, body: String, input_values: &[I], output_values: &[O], ) -> Self { Self { name, header: String::new(), custom_struct_members, body, input_type: String::from("CustomStruct"), output_type: String::from("array"), input_values: bytemuck::cast_slice(input_values).to_vec(), output_values: vec![bytemuck::cast_slice(output_values).to_vec()], output_comparison_fn: Self::default_comparison_function::, output_initialization: u32::MAX, failures: Backends::empty(), } } fn header(mut self, header: String) -> Self { self.header = header; self } fn output_type(mut self, output_type: String) -> Self { self.output_type = output_type; self } /// Add another set of possible outputs. If any of the given /// output values are seen it's considered a success (i.e. this is OR, not AND). /// /// Assumes that this type O is the same as the O provided to new. fn extra_output_values( mut self, output_values: &[O], ) -> Self { self.output_values .push(bytemuck::cast_slice(output_values).to_vec()); self } fn failures(mut self, failures: Backends) -> Self { self.failures = failures; self } } const MAX_BUFFER_SIZE: u64 = 128; /// Runs the given shader tests with the given storage_type for the input_buffer. async fn shader_input_output_test( ctx: TestingContext, storage_type: InputStorageType, tests: Vec, ) { let source = String::from(include_str!("shader_test.wgsl")); let bgl = ctx .device .create_bind_group_layout(&BindGroupLayoutDescriptor { label: None, entries: &[ BindGroupLayoutEntry { binding: 0, visibility: ShaderStages::COMPUTE, ty: BindingType::Buffer { // We don't use this buffer for push constants, but for simplicity // we just use the storage buffer binding. ty: match storage_type { InputStorageType::Uniform => wgpu::BufferBindingType::Uniform, InputStorageType::Storage | InputStorageType::PushConstant => { wgpu::BufferBindingType::Storage { read_only: true } } }, has_dynamic_offset: false, min_binding_size: None, }, count: None, }, BindGroupLayoutEntry { binding: 1, visibility: ShaderStages::COMPUTE, ty: BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: false }, has_dynamic_offset: false, min_binding_size: None, }, count: None, }, ], }); let input_buffer = ctx.device.create_buffer(&BufferDescriptor { label: Some("input buffer"), size: MAX_BUFFER_SIZE, usage: BufferUsages::COPY_DST | BufferUsages::UNIFORM | BufferUsages::STORAGE, mapped_at_creation: false, }); let output_buffer = ctx.device.create_buffer(&BufferDescriptor { label: Some("output buffer"), size: MAX_BUFFER_SIZE, usage: BufferUsages::COPY_DST | BufferUsages::COPY_SRC | BufferUsages::STORAGE, mapped_at_creation: false, }); let mapping_buffer = ctx.device.create_buffer(&BufferDescriptor { label: Some("mapping buffer"), size: MAX_BUFFER_SIZE, usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ, mapped_at_creation: false, }); let bg = ctx.device.create_bind_group(&BindGroupDescriptor { label: None, layout: &bgl, entries: &[ BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding(), }, BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding(), }, ], }); let pll = ctx .device .create_pipeline_layout(&PipelineLayoutDescriptor { label: None, bind_group_layouts: &[&bgl], push_constant_ranges: match storage_type { InputStorageType::PushConstant => &[PushConstantRange { stages: ShaderStages::COMPUTE, range: 0..MAX_BUFFER_SIZE as u32, }], _ => &[], }, }); let mut fail = false; for test in tests { assert!(test.input_values.len() <= MAX_BUFFER_SIZE as usize / 4); assert!(test.output_values.len() <= MAX_BUFFER_SIZE as usize / 4); let test_name = test.name; // -- Building shader + pipeline -- // This isn't terribly efficient but the string is short and it's a test. // The body and input members are the longest part, so do them last. let mut processed = source .replace("{{header}}", &test.header) .replace("{{storage_type}}", storage_type.as_str()) .replace("{{input_type}}", &test.input_type) .replace("{{output_type}}", &test.output_type) .replace("{{input_members}}", &test.custom_struct_members) .replace("{{body}}", &test.body); // Add the bindings for all inputs besides push constants. processed = if matches!(storage_type, InputStorageType::PushConstant) { processed.replace("{{input_bindings}}", "") } else { processed.replace("{{input_bindings}}", "@group(0) @binding(0)") }; let sm = ctx.device.create_shader_module(ShaderModuleDescriptor { label: Some(&format!("shader {test_name}")), source: ShaderSource::Wgsl(Cow::Borrowed(&processed)), }); let pipeline = ctx .device .create_compute_pipeline(&ComputePipelineDescriptor { label: Some(&format!("pipeline {test_name}")), layout: Some(&pll), module: &sm, entry_point: Some("cs_main"), compilation_options: Default::default(), cache: None, }); // -- Initializing data -- let output_pre_init_data = vec![test.output_initialization; MAX_BUFFER_SIZE as usize / 4]; ctx.queue.write_buffer( &output_buffer, 0, bytemuck::cast_slice(&output_pre_init_data), ); match storage_type { InputStorageType::Uniform | InputStorageType::Storage => { ctx.queue .write_buffer(&input_buffer, 0, bytemuck::cast_slice(&test.input_values)); } _ => { // Init happens in the compute pass } } // -- Run test -- let mut encoder = ctx .device .create_command_encoder(&CommandEncoderDescriptor { label: None }); let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor { label: Some(&format!("cpass {test_name}")), timestamp_writes: None, }); cpass.set_pipeline(&pipeline); cpass.set_bind_group(0, &bg, &[]); if let InputStorageType::PushConstant = storage_type { cpass.set_push_constants(0, bytemuck::cast_slice(&test.input_values)) } cpass.dispatch_workgroups(1, 1, 1); drop(cpass); // -- Pulldown data -- encoder.copy_buffer_to_buffer(&output_buffer, 0, &mapping_buffer, 0, MAX_BUFFER_SIZE); ctx.queue.submit(Some(encoder.finish())); mapping_buffer.slice(..).map_async(MapMode::Read, |_| ()); ctx.async_poll(PollType::wait()).await.unwrap(); let mapped = mapping_buffer.slice(..).get_mapped_range(); let typed: &[u32] = bytemuck::cast_slice(&mapped); // -- Check results -- let failure = !(test.output_comparison_fn)(&test_name, typed, &test.output_values); // We don't immediately panic to let all tests execute if failure != test .failures .contains(ctx.adapter.get_info().backend.into()) { fail |= true; if !failure { eprintln!("Unexpected test success. Test {test_name}"); } } drop(mapped); mapping_buffer.unmap(); } assert!(!fail); }