From 074c0e7191bdb5940125121378050bd4376b9ad1 Mon Sep 17 00:00:00 2001 From: SupaMaggie70Incorporated <85136135+SupaMaggie70Incorporated@users.noreply.github.com> Date: Thu, 24 Jul 2025 19:58:56 -0500 Subject: [PATCH] Add mesh shading api to wgpu & wgpu-core (#7345) --- CHANGELOG.md | 1 + cts_runner/tests/integration.rs | 2 +- examples/README.md | 1 + examples/features/src/lib.rs | 1 + examples/features/src/main.rs | 6 + examples/features/src/mesh_shader/README.md | 9 + examples/features/src/mesh_shader/mod.rs | 142 ++++++ examples/features/src/mesh_shader/shader.frag | 11 + examples/features/src/mesh_shader/shader.mesh | 38 ++ examples/features/src/mesh_shader/shader.task | 16 + .../standalone/custom_backend/src/custom.rs | 7 + player/src/lib.rs | 6 + tests/tests/wgpu-gpu/main.rs | 1 + tests/tests/wgpu-gpu/mesh_shader/basic.frag | 9 + tests/tests/wgpu-gpu/mesh_shader/basic.mesh | 25 + tests/tests/wgpu-gpu/mesh_shader/basic.task | 6 + tests/tests/wgpu-gpu/mesh_shader/mod.rs | 310 ++++++++++++ .../tests/wgpu-gpu/mesh_shader/no-write.frag | 7 + wgpu-core/src/command/bundle.rs | 102 +++- wgpu-core/src/command/draw.rs | 15 + wgpu-core/src/command/mod.rs | 14 +- wgpu-core/src/command/render.rs | 282 ++++++++--- wgpu-core/src/command/render_command.rs | 41 +- wgpu-core/src/device/global.rs | 172 +++++-- wgpu-core/src/device/resource.rs | 442 +++++++++++------- wgpu-core/src/device/trace.rs | 4 + wgpu-core/src/indirect_validation/draw.rs | 6 +- wgpu-core/src/pipeline.rs | 132 +++++- wgpu-core/src/validation.rs | 1 + wgpu-hal/examples/halmark/main.rs | 14 +- wgpu-hal/src/dx12/adapter.rs | 6 + wgpu-hal/src/dx12/device.rs | 23 +- wgpu-hal/src/dynamic/device.rs | 60 +-- wgpu-hal/src/gles/adapter.rs | 6 + wgpu-hal/src/gles/device.rs | 21 +- wgpu-hal/src/lib.rs | 56 +-- wgpu-hal/src/metal/adapter.rs | 6 + wgpu-hal/src/metal/device.rs | 31 +- wgpu-hal/src/noop/mod.rs | 16 +- wgpu-hal/src/vulkan/adapter.rs | 24 +- wgpu-hal/src/vulkan/conv.rs | 6 + wgpu-hal/src/vulkan/device.rs | 301 +++--------- wgpu-info/src/human.rs | 12 + wgpu-types/src/lib.rs | 66 ++- wgpu/src/api/device.rs | 7 + wgpu/src/api/render_pass.rs | 70 +++ wgpu/src/api/render_pipeline.rs | 90 ++++ wgpu/src/backend/webgpu.rs | 45 ++ wgpu/src/backend/wgpu_core.rs | 188 ++++++++ wgpu/src/dispatch.rs | 24 + 50 files changed, 2208 insertions(+), 673 deletions(-) create mode 100644 examples/features/src/mesh_shader/README.md create mode 100644 examples/features/src/mesh_shader/mod.rs create mode 100644 examples/features/src/mesh_shader/shader.frag create mode 100644 examples/features/src/mesh_shader/shader.mesh create mode 100644 examples/features/src/mesh_shader/shader.task create mode 100644 tests/tests/wgpu-gpu/mesh_shader/basic.frag create mode 100644 tests/tests/wgpu-gpu/mesh_shader/basic.mesh create mode 100644 tests/tests/wgpu-gpu/mesh_shader/basic.task create mode 100644 tests/tests/wgpu-gpu/mesh_shader/mod.rs create mode 100644 tests/tests/wgpu-gpu/mesh_shader/no-write.frag diff --git a/CHANGELOG.md b/CHANGELOG.md index 11ffb4533a..7d9558c36b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -157,6 +157,7 @@ By @Vecvec in [#7829](https://github.com/gfx-rs/wgpu/pull/7829). - Add acceleration structure limits. By @Vecvec in [#7845](https://github.com/gfx-rs/wgpu/pull/7845). - Add support for clip-distances feature for Vulkan and GL backends. By @dzamkov in [#7730](https://github.com/gfx-rs/wgpu/pull/7730) - Added `wgpu_types::error::{ErrorType, WebGpuError}` for classification of errors according to WebGPU's [`GPUError`]'s classification scheme, and implement `WebGpuError` for existing errors. This allows users of `wgpu-core` to offload error classification onto the WGPU ecosystem, rather than having to do it themselves without sufficient information. By @ErichDonGubler in [#6547](https://github.com/gfx-rs/wgpu/pull/6547). +- Added mesh shader support to `wgpu`, with examples. Requires passthrough. By @SupaMaggie70Incorporated in [#7345](https://github.com/gfx-rs/wgpu/pull/7345). [`GPUError`]: https://www.w3.org/TR/webgpu/#gpuerror diff --git a/cts_runner/tests/integration.rs b/cts_runner/tests/integration.rs index 199f00e631..bb2118b078 100644 --- a/cts_runner/tests/integration.rs +++ b/cts_runner/tests/integration.rs @@ -34,7 +34,7 @@ impl Display for JsError { impl Debug for JsError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } diff --git a/examples/README.md b/examples/README.md index aa083bc390..a303eb5292 100644 --- a/examples/README.md +++ b/examples/README.md @@ -43,6 +43,7 @@ These examples use a common framework to handle wgpu init, window creation, and - `ray_cube_fragment` - Demonstrates using ray queries with a fragment shader. - `ray_scene` - Demonstrates using ray queries and model loading - `ray_shadows` - Demonstrates a simple use of ray queries - high quality shadows - uses a light set with push constants to raytrace through an untransformed scene and detect whether there is something obstructing the light. +- `mesh_shader` - Rrenders a triangle to a window with mesh shaders, while showcasing most mesh shader related features(task shaders, payloads, per primitive data). #### Compute diff --git a/examples/features/src/lib.rs b/examples/features/src/lib.rs index 479f1d12e9..271241f1ae 100644 --- a/examples/features/src/lib.rs +++ b/examples/features/src/lib.rs @@ -13,6 +13,7 @@ pub mod hello_synchronization; pub mod hello_triangle; pub mod hello_windows; pub mod hello_workgroups; +pub mod mesh_shader; pub mod mipmap; pub mod msaa_line; pub mod multiple_render_targets; diff --git a/examples/features/src/main.rs b/examples/features/src/main.rs index b90c90d26c..c9fc31259c 100644 --- a/examples/features/src/main.rs +++ b/examples/features/src/main.rs @@ -182,6 +182,12 @@ const EXAMPLES: &[ExampleDesc] = &[ webgl: false, // No Ray-tracing extensions webgpu: false, // No Ray-tracing extensions (yet) }, + ExampleDesc { + name: "mesh_shader", + function: wgpu_examples::mesh_shader::main, + webgl: false, + webgpu: false, + }, ]; fn get_example_name() -> Option { diff --git a/examples/features/src/mesh_shader/README.md b/examples/features/src/mesh_shader/README.md new file mode 100644 index 0000000000..9b57d3e490 --- /dev/null +++ b/examples/features/src/mesh_shader/README.md @@ -0,0 +1,9 @@ +# mesh_shader + +This example renders a triangle to a window with mesh shaders, while showcasing most mesh shader related features(task shaders, payloads, per primitive data). + +## To Run + +``` +cargo run --bin wgpu-examples mesh_shader +``` \ No newline at end of file diff --git a/examples/features/src/mesh_shader/mod.rs b/examples/features/src/mesh_shader/mod.rs new file mode 100644 index 0000000000..956722a661 --- /dev/null +++ b/examples/features/src/mesh_shader/mod.rs @@ -0,0 +1,142 @@ +use std::{io::Write, process::Stdio}; + +// Same as in mesh shader tests +fn compile_glsl( + device: &wgpu::Device, + data: &[u8], + shader_stage: &'static str, +) -> wgpu::ShaderModule { + let cmd = std::process::Command::new("glslc") + .args([ + &format!("-fshader-stage={shader_stage}"), + "-", + "-o", + "-", + "--target-env=vulkan1.2", + "--target-spv=spv1.4", + ]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .expect("Failed to call glslc"); + cmd.stdin.as_ref().unwrap().write_all(data).unwrap(); + println!("{shader_stage}"); + let output = cmd.wait_with_output().expect("Error waiting for glslc"); + assert!(output.status.success()); + unsafe { + device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV( + wgpu::ShaderModuleDescriptorSpirV { + label: None, + source: wgpu::util::make_spirv_raw(&output.stdout), + }, + )) + } +} + +pub struct Example { + pipeline: wgpu::RenderPipeline, +} +impl crate::framework::Example for Example { + fn init( + config: &wgpu::SurfaceConfiguration, + _adapter: &wgpu::Adapter, + device: &wgpu::Device, + _queue: &wgpu::Queue, + ) -> Self { + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[], + push_constant_ranges: &[], + }); + let (ts, ms, fs) = ( + compile_glsl(device, include_bytes!("shader.task"), "task"), + compile_glsl(device, include_bytes!("shader.mesh"), "mesh"), + compile_glsl(device, include_bytes!("shader.frag"), "frag"), + ); + let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor { + label: None, + layout: Some(&pipeline_layout), + task: Some(wgpu::TaskState { + module: &ts, + entry_point: Some("main"), + compilation_options: Default::default(), + }), + mesh: wgpu::MeshState { + module: &ms, + entry_point: Some("main"), + compilation_options: Default::default(), + }, + fragment: Some(wgpu::FragmentState { + module: &fs, + entry_point: Some("main"), + compilation_options: Default::default(), + targets: &[Some(config.view_formats[0].into())], + }), + primitive: wgpu::PrimitiveState { + cull_mode: Some(wgpu::Face::Back), + ..Default::default() + }, + depth_stencil: None, + multisample: Default::default(), + multiview: None, + cache: None, + }); + Self { pipeline } + } + fn render(&mut self, view: &wgpu::TextureView, device: &wgpu::Device, queue: &wgpu::Queue) { + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { + label: None, + color_attachments: &[Some(wgpu::RenderPassColorAttachment { + view, + resolve_target: None, + ops: wgpu::Operations { + load: wgpu::LoadOp::Clear(wgpu::Color { + r: 0.1, + g: 0.2, + b: 0.3, + a: 1.0, + }), + store: wgpu::StoreOp::Store, + }, + depth_slice: None, + })], + depth_stencil_attachment: None, + timestamp_writes: None, + occlusion_query_set: None, + }); + rpass.push_debug_group("Prepare data for draw."); + rpass.set_pipeline(&self.pipeline); + rpass.pop_debug_group(); + rpass.insert_debug_marker("Draw!"); + rpass.draw_mesh_tasks(1, 1, 1); + } + queue.submit(Some(encoder.finish())); + } + fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities { + Default::default() + } + fn required_features() -> wgpu::Features { + wgpu::Features::EXPERIMENTAL_MESH_SHADER | wgpu::Features::SPIRV_SHADER_PASSTHROUGH + } + fn required_limits() -> wgpu::Limits { + wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values() + } + fn resize( + &mut self, + _config: &wgpu::SurfaceConfiguration, + _device: &wgpu::Device, + _queue: &wgpu::Queue, + ) { + // empty + } + fn update(&mut self, _event: winit::event::WindowEvent) { + // empty + } +} + +pub fn main() { + crate::framework::run::("mesh_shader"); +} diff --git a/examples/features/src/mesh_shader/shader.frag b/examples/features/src/mesh_shader/shader.frag new file mode 100644 index 0000000000..49624990f1 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.frag @@ -0,0 +1,11 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +in VertexInput { layout(location = 0) vec4 color; } +vertexInput; +layout(location = 1) perprimitiveEXT in PrimitiveInput { vec4 colorMask; } +primitiveInput; + +layout(location = 0) out vec4 fragColor; + +void main() { fragColor = vertexInput.color * primitiveInput.colorMask; } \ No newline at end of file diff --git a/examples/features/src/mesh_shader/shader.mesh b/examples/features/src/mesh_shader/shader.mesh new file mode 100644 index 0000000000..7d350e8ce7 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.mesh @@ -0,0 +1,38 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +const vec4[3] positions = {vec4(0., 1.0, 0., 1.0), vec4(-1.0, -1.0, 0., 1.0), + vec4(1.0, -1.0, 0., 1.0)}; +const vec4[3] colors = {vec4(0., 1., 0., 1.), vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.)}; + +// This is an inefficient workgroup size.Ideally the total thread count would be +// a multiple of 64 +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +struct PayloadData { + vec4 colorMask; + bool visible; +}; +taskPayloadSharedEXT PayloadData payloadData; + +out VertexOutput { layout(location = 0) vec4 color; } +vertexOutput[]; +layout(location = 1) perprimitiveEXT out PrimitiveOutput { vec4 colorMask; } +primitiveOutput[]; + +shared uint sharedData; + +layout(triangles, max_vertices = 3, max_primitives = 1) out; +void main() { + sharedData = 5; + SetMeshOutputsEXT(3, 1); + gl_MeshVerticesEXT[0].gl_Position = positions[0]; + gl_MeshVerticesEXT[1].gl_Position = positions[1]; + gl_MeshVerticesEXT[2].gl_Position = positions[2]; + vertexOutput[0].color = colors[0] * payloadData.colorMask; + vertexOutput[1].color = colors[1] * payloadData.colorMask; + vertexOutput[2].color = colors[2] * payloadData.colorMask; + gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(0, 1, 2); + primitiveOutput[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); + gl_MeshPrimitivesEXT[0].gl_CullPrimitiveEXT = !payloadData.visible; +} \ No newline at end of file diff --git a/examples/features/src/mesh_shader/shader.task b/examples/features/src/mesh_shader/shader.task new file mode 100644 index 0000000000..6c766bc83a --- /dev/null +++ b/examples/features/src/mesh_shader/shader.task @@ -0,0 +1,16 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; + +struct TaskPayload { + vec4 colorMask; + bool visible; +}; +taskPayloadSharedEXT TaskPayload taskPayload; + +void main() { + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + EmitMeshTasksEXT(3, 1, 1); +} \ No newline at end of file diff --git a/examples/standalone/custom_backend/src/custom.rs b/examples/standalone/custom_backend/src/custom.rs index 8ca7df80bd..6381c48d1d 100644 --- a/examples/standalone/custom_backend/src/custom.rs +++ b/examples/standalone/custom_backend/src/custom.rs @@ -161,6 +161,13 @@ impl DeviceInterface for CustomDevice { unimplemented!() } + fn create_mesh_pipeline( + &self, + _desc: &wgpu::MeshPipelineDescriptor<'_>, + ) -> wgpu::custom::DispatchRenderPipeline { + unimplemented!() + } + fn create_compute_pipeline( &self, desc: &wgpu::ComputePipelineDescriptor<'_>, diff --git a/player/src/lib.rs b/player/src/lib.rs index 24c520c276..5076db0267 100644 --- a/player/src/lib.rs +++ b/player/src/lib.rs @@ -333,6 +333,12 @@ impl GlobalPlay for wgc::global::Global { panic!("{e}"); } } + Action::CreateMeshPipeline { id, desc } => { + let (_, error) = self.device_create_mesh_pipeline(device, &desc, Some(id)); + if let Some(e) = error { + panic!("{e}"); + } + } Action::DestroyRenderPipeline(id) => { self.render_pipeline_drop(id); } diff --git a/tests/tests/wgpu-gpu/main.rs b/tests/tests/wgpu-gpu/main.rs index 8fd24c4343..2e5462f2e6 100644 --- a/tests/tests/wgpu-gpu/main.rs +++ b/tests/tests/wgpu-gpu/main.rs @@ -34,6 +34,7 @@ mod image_atomics; mod instance; mod life_cycle; mod mem_leaks; +mod mesh_shader; mod nv12_texture; mod occlusion_query; mod oob_indexing; diff --git a/tests/tests/wgpu-gpu/mesh_shader/basic.frag b/tests/tests/wgpu-gpu/mesh_shader/basic.frag new file mode 100644 index 0000000000..9d2b777326 --- /dev/null +++ b/tests/tests/wgpu-gpu/mesh_shader/basic.frag @@ -0,0 +1,9 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +in VertexInput { layout(location = 0) vec4 color; } +vertexInput; + +layout(location = 0) out vec4 fragColor; + +void main() { fragColor = vertexInput.color; } \ No newline at end of file diff --git a/tests/tests/wgpu-gpu/mesh_shader/basic.mesh b/tests/tests/wgpu-gpu/mesh_shader/basic.mesh new file mode 100644 index 0000000000..400cafb36f --- /dev/null +++ b/tests/tests/wgpu-gpu/mesh_shader/basic.mesh @@ -0,0 +1,25 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +const vec4[3] positions = {vec4(0., 1.0, 0., 1.0), vec4(-1.0, -1.0, 0., 1.0), + vec4(1.0, -1.0, 0., 1.0)}; +const vec4[3] colors = {vec4(0., 1., 0., 1.), vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.)}; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +out VertexOutput { layout(location = 0) vec4 color; } +vertexOutput[]; + +layout(triangles, max_vertices = 3, max_primitives = 1) out; + +void main() { + SetMeshOutputsEXT(3, 1); + gl_MeshVerticesEXT[0].gl_Position = positions[0]; + gl_MeshVerticesEXT[1].gl_Position = positions[1]; + gl_MeshVerticesEXT[2].gl_Position = positions[2]; + vertexOutput[0].color = colors[0]; + vertexOutput[1].color = colors[1]; + vertexOutput[2].color = colors[2]; + gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(0, 1, 2); +} \ No newline at end of file diff --git a/tests/tests/wgpu-gpu/mesh_shader/basic.task b/tests/tests/wgpu-gpu/mesh_shader/basic.task new file mode 100644 index 0000000000..418cffa3f6 --- /dev/null +++ b/tests/tests/wgpu-gpu/mesh_shader/basic.task @@ -0,0 +1,6 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; + +void main() { EmitMeshTasksEXT(1, 1, 1); } \ No newline at end of file diff --git a/tests/tests/wgpu-gpu/mesh_shader/mod.rs b/tests/tests/wgpu-gpu/mesh_shader/mod.rs new file mode 100644 index 0000000000..cd1ca42542 --- /dev/null +++ b/tests/tests/wgpu-gpu/mesh_shader/mod.rs @@ -0,0 +1,310 @@ +use std::{io::Write, process::Stdio}; + +use wgpu::util::DeviceExt; +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; + +// Same as in mesh shader example +fn compile_glsl( + device: &wgpu::Device, + data: &[u8], + shader_stage: &'static str, +) -> wgpu::ShaderModule { + let cmd = std::process::Command::new("glslc") + .args([ + &format!("-fshader-stage={shader_stage}"), + "-", + "-o", + "-", + "--target-env=vulkan1.2", + "--target-spv=spv1.4", + ]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .expect("Failed to call glslc"); + cmd.stdin.as_ref().unwrap().write_all(data).unwrap(); + println!("{shader_stage}"); + let output = cmd.wait_with_output().expect("Error waiting for glslc"); + assert!(output.status.success()); + unsafe { + device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV( + wgpu::ShaderModuleDescriptorSpirV { + label: None, + source: wgpu::util::make_spirv_raw(&output.stdout), + }, + )) + } +} + +fn create_depth( + device: &wgpu::Device, +) -> (wgpu::Texture, wgpu::TextureView, wgpu::DepthStencilState) { + let image_size = wgpu::Extent3d { + width: 64, + height: 64, + depth_or_array_layers: 1, + }; + let depth_texture = device.create_texture(&wgpu::TextureDescriptor { + label: None, + size: image_size, + mip_level_count: 1, + sample_count: 1, + dimension: wgpu::TextureDimension::D2, + format: wgpu::TextureFormat::Depth32Float, + usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING, + view_formats: &[], + }); + let depth_view = depth_texture.create_view(&Default::default()); + let state = wgpu::DepthStencilState { + format: wgpu::TextureFormat::Depth32Float, + depth_write_enabled: true, + depth_compare: wgpu::CompareFunction::Less, // 1. + stencil: wgpu::StencilState::default(), // 2. + bias: wgpu::DepthBiasState::default(), + }; + (depth_texture, depth_view, state) +} + +fn mesh_pipeline_build( + ctx: &TestingContext, + task: Option<&[u8]>, + mesh: &[u8], + frag: Option<&[u8]>, + draw: bool, +) { + let device = &ctx.device; + let (_depth_image, depth_view, depth_state) = create_depth(device); + let task = task.map(|t| compile_glsl(device, t, "task")); + let mesh = compile_glsl(device, mesh, "mesh"); + let frag = frag.map(|f| compile_glsl(device, f, "frag")); + let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[], + push_constant_ranges: &[], + }); + let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor { + label: None, + layout: Some(&layout), + task: task.as_ref().map(|task| wgpu::TaskState { + module: task, + entry_point: Some("main"), + compilation_options: Default::default(), + }), + mesh: wgpu::MeshState { + module: &mesh, + entry_point: Some("main"), + compilation_options: Default::default(), + }, + fragment: frag.as_ref().map(|frag| wgpu::FragmentState { + module: frag, + entry_point: Some("main"), + targets: &[], + compilation_options: Default::default(), + }), + primitive: wgpu::PrimitiveState { + cull_mode: Some(wgpu::Face::Back), + ..Default::default() + }, + depth_stencil: Some(depth_state), + multisample: Default::default(), + multiview: None, + cache: None, + }); + if draw { + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { + label: None, + color_attachments: &[], + depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment { + view: &depth_view, + depth_ops: Some(wgpu::Operations { + load: wgpu::LoadOp::Clear(1.0), + store: wgpu::StoreOp::Store, + }), + stencil_ops: None, + }), + timestamp_writes: None, + occlusion_query_set: None, + }); + pass.set_pipeline(&pipeline); + pass.draw_mesh_tasks(1, 1, 1); + } + ctx.queue.submit(Some(encoder.finish())); + ctx.device.poll(wgpu::PollType::Wait).unwrap(); + } +} + +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum DrawType { + #[allow(dead_code)] + Standard, + Indirect, + MultiIndirect, + MultiIndirectCount, +} + +fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { + let device = &ctx.device; + let (_depth_image, depth_view, depth_state) = create_depth(device); + let task = compile_glsl(device, BASIC_TASK, "task"); + let mesh = compile_glsl(device, BASIC_MESH, "mesh"); + let frag = compile_glsl(device, NO_WRITE_FRAG, "frag"); + let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[], + push_constant_ranges: &[], + }); + let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor { + label: None, + layout: Some(&layout), + task: Some(wgpu::TaskState { + module: &task, + entry_point: Some("main"), + compilation_options: Default::default(), + }), + mesh: wgpu::MeshState { + module: &mesh, + entry_point: Some("main"), + compilation_options: Default::default(), + }, + fragment: Some(wgpu::FragmentState { + module: &frag, + entry_point: Some("main"), + targets: &[], + compilation_options: Default::default(), + }), + primitive: wgpu::PrimitiveState { + cull_mode: Some(wgpu::Face::Back), + ..Default::default() + }, + depth_stencil: Some(depth_state), + multisample: Default::default(), + multiview: None, + cache: None, + }); + let buffer = match draw_type { + DrawType::Standard => None, + DrawType::Indirect | DrawType::MultiIndirect | DrawType::MultiIndirectCount => Some( + device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: None, + usage: wgpu::BufferUsages::INDIRECT, + contents: bytemuck::bytes_of(&[1u32; 4]), + }), + ), + }; + let count_buffer = match draw_type { + DrawType::MultiIndirectCount => Some(device.create_buffer_init( + &wgpu::util::BufferInitDescriptor { + label: None, + usage: wgpu::BufferUsages::INDIRECT, + contents: bytemuck::bytes_of(&[1u32; 1]), + }, + )), + _ => None, + }; + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { + label: None, + color_attachments: &[], + depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment { + view: &depth_view, + depth_ops: Some(wgpu::Operations { + load: wgpu::LoadOp::Clear(1.0), + store: wgpu::StoreOp::Store, + }), + stencil_ops: None, + }), + timestamp_writes: None, + occlusion_query_set: None, + }); + pass.set_pipeline(&pipeline); + match draw_type { + DrawType::Standard => pass.draw_mesh_tasks(1, 1, 1), + DrawType::Indirect => pass.draw_mesh_tasks_indirect(buffer.as_ref().unwrap(), 0), + DrawType::MultiIndirect => { + pass.multi_draw_mesh_tasks_indirect(buffer.as_ref().unwrap(), 0, 1) + } + DrawType::MultiIndirectCount => pass.multi_draw_mesh_tasks_indirect_count( + buffer.as_ref().unwrap(), + 0, + count_buffer.as_ref().unwrap(), + 0, + 1, + ), + } + pass.draw_mesh_tasks_indirect(buffer.as_ref().unwrap(), 0); + } + ctx.queue.submit(Some(encoder.finish())); + ctx.device.poll(wgpu::PollType::Wait).unwrap(); +} + +const BASIC_TASK: &[u8] = include_bytes!("basic.task"); +const BASIC_MESH: &[u8] = include_bytes!("basic.mesh"); +//const BASIC_FRAG: &[u8] = include_bytes!("basic.frag.spv"); +const NO_WRITE_FRAG: &[u8] = include_bytes!("no-write.frag"); + +fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration { + GpuTestConfiguration::new().parameters( + TestParameters::default() + .test_features_limits() + .features( + wgpu::Features::EXPERIMENTAL_MESH_SHADER + | wgpu::Features::SPIRV_SHADER_PASSTHROUGH + | match draw_type { + DrawType::Standard | DrawType::Indirect => wgpu::Features::empty(), + DrawType::MultiIndirect => wgpu::Features::MULTI_DRAW_INDIRECT, + DrawType::MultiIndirectCount => wgpu::Features::MULTI_DRAW_INDIRECT_COUNT, + }, + ) + .limits(wgpu::Limits::default().using_recommended_minimum_mesh_shader_values()), + ) +} + +// Mesh pipeline configs +#[gpu_test] +static MESH_PIPELINE_BASIC_MESH: GpuTestConfiguration = default_gpu_test_config(DrawType::Standard) + .run_sync(|ctx| { + mesh_pipeline_build(&ctx, None, BASIC_MESH, None, true); + }); +#[gpu_test] +static MESH_PIPELINE_BASIC_TASK_MESH: GpuTestConfiguration = + default_gpu_test_config(DrawType::Standard).run_sync(|ctx| { + mesh_pipeline_build(&ctx, Some(BASIC_TASK), BASIC_MESH, None, true); + }); +#[gpu_test] +static MESH_PIPELINE_BASIC_MESH_FRAG: GpuTestConfiguration = + default_gpu_test_config(DrawType::Standard).run_sync(|ctx| { + mesh_pipeline_build(&ctx, None, BASIC_MESH, Some(NO_WRITE_FRAG), true); + }); +#[gpu_test] +static MESH_PIPELINE_BASIC_TASK_MESH_FRAG: GpuTestConfiguration = + default_gpu_test_config(DrawType::Standard).run_sync(|ctx| { + mesh_pipeline_build( + &ctx, + Some(BASIC_TASK), + BASIC_MESH, + Some(NO_WRITE_FRAG), + true, + ); + }); + +// Mesh draw +#[gpu_test] +static MESH_DRAW_INDIRECT: GpuTestConfiguration = default_gpu_test_config(DrawType::Indirect) + .run_sync(|ctx| { + mesh_draw(&ctx, DrawType::Indirect); + }); +#[gpu_test] +static MESH_MULTI_DRAW_INDIRECT: GpuTestConfiguration = + default_gpu_test_config(DrawType::MultiIndirect).run_sync(|ctx| { + mesh_draw(&ctx, DrawType::MultiIndirect); + }); +#[gpu_test] +static MESH_MULTI_DRAW_INDIRECT_COUNT: GpuTestConfiguration = + default_gpu_test_config(DrawType::MultiIndirectCount).run_sync(|ctx| { + mesh_draw(&ctx, DrawType::MultiIndirectCount); + }); diff --git a/tests/tests/wgpu-gpu/mesh_shader/no-write.frag b/tests/tests/wgpu-gpu/mesh_shader/no-write.frag new file mode 100644 index 0000000000..d0512bb0fa --- /dev/null +++ b/tests/tests/wgpu-gpu/mesh_shader/no-write.frag @@ -0,0 +1,7 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +in VertexInput { layout(location = 0) vec4 color; } +vertexInput; + +void main() {} \ No newline at end of file diff --git a/wgpu-core/src/command/bundle.rs b/wgpu-core/src/command/bundle.rs index 09f522f001..8e8bf90b14 100644 --- a/wgpu-core/src/command/bundle.rs +++ b/wgpu-core/src/command/bundle.rs @@ -123,7 +123,7 @@ use crate::{ use super::{ pass, render_command::{ArcRenderCommand, RenderCommand}, - DrawKind, + DrawCommandFamily, DrawKind, }; /// Describes a [`RenderBundleEncoder`]. @@ -380,7 +380,7 @@ impl RenderBundleEncoder { } => { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: false, + family: DrawCommandFamily::Draw, }; draw( &mut state, @@ -401,7 +401,7 @@ impl RenderBundleEncoder { } => { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; draw_indexed( &mut state, @@ -414,15 +414,33 @@ impl RenderBundleEncoder { ) .map_pass_err(scope)?; } + RenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + } => { + let scope = PassErrorScope::Draw { + kind: DrawKind::Draw, + family: DrawCommandFamily::DrawMeshTasks, + }; + draw_mesh_tasks( + &mut state, + &base.dynamic_offsets, + group_count_x, + group_count_y, + group_count_z, + ) + .map_pass_err(scope)?; + } RenderCommand::DrawIndirect { buffer_id, offset, count: 1, - indexed, + family, } => { let scope = PassErrorScope::Draw { kind: DrawKind::DrawIndirect, - indexed, + family, }; multi_draw_indirect( &mut state, @@ -430,7 +448,7 @@ impl RenderBundleEncoder { &buffer_guard, buffer_id, offset, - indexed, + family, ) .map_pass_err(scope)?; } @@ -787,13 +805,48 @@ fn draw_indexed( Ok(()) } +fn draw_mesh_tasks( + state: &mut State, + dynamic_offsets: &[u32], + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, +) -> Result<(), RenderBundleErrorInner> { + let pipeline = state.pipeline()?; + let used_bind_groups = pipeline.used_bind_groups; + + let groups_size_limit = state.device.limits.max_task_workgroups_per_dimension; + let max_groups = state.device.limits.max_task_workgroup_total_count; + if group_count_x > groups_size_limit + || group_count_y > groups_size_limit + || group_count_z > groups_size_limit + || group_count_x * group_count_y * group_count_z > max_groups + { + return Err(RenderBundleErrorInner::Draw(DrawError::InvalidGroupSize { + current: [group_count_x, group_count_y, group_count_z], + limit: groups_size_limit, + max_total: max_groups, + })); + } + + if group_count_x > 0 && group_count_y > 0 && group_count_z > 0 { + state.flush_binds(used_bind_groups, dynamic_offsets); + state.commands.push(ArcRenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + }); + } + Ok(()) +} + fn multi_draw_indirect( state: &mut State, dynamic_offsets: &[u32], buffer_guard: &crate::storage::Storage>, buffer_id: id::Id, offset: u64, - indexed: bool, + family: DrawCommandFamily, ) -> Result<(), RenderBundleErrorInner> { state .device @@ -809,7 +862,7 @@ fn multi_draw_indirect( let vertex_limits = super::VertexLimits::new(state.vertex_buffer_sizes(), &pipeline.steps); - let stride = super::get_stride_of_indirect_args(indexed); + let stride = super::get_stride_of_indirect_args(family); state .buffer_memory_init_actions .extend(buffer.initialization_status.read().create_action( @@ -818,7 +871,7 @@ fn multi_draw_indirect( MemoryInitKind::NeedsInitializedMemory, )); - let vertex_or_index_limit = if indexed { + let vertex_or_index_limit = if family == DrawCommandFamily::DrawIndexed { let index = match state.index { Some(ref mut index) => index, None => return Err(DrawError::MissingIndexBuffer.into()), @@ -844,7 +897,7 @@ fn multi_draw_indirect( buffer, offset, count: 1, - indexed, + family, vertex_or_index_limit, instance_limit, @@ -1066,11 +1119,18 @@ impl RenderBundle { ) }; } + Cmd::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + } => unsafe { + raw.draw_mesh_tasks(*group_count_x, *group_count_y, *group_count_z); + }, Cmd::DrawIndirect { buffer, offset, count: 1, - indexed, + family, vertex_or_index_limit, instance_limit, @@ -1081,7 +1141,7 @@ impl RenderBundle { &self.device, buffer, *offset, - *indexed, + *family, *vertex_or_index_limit, *instance_limit, )?; @@ -1092,10 +1152,14 @@ impl RenderBundle { } else { (buffer.try_raw(snatch_guard)?, *offset) }; - if *indexed { - unsafe { raw.draw_indexed_indirect(buffer, offset, 1) }; - } else { - unsafe { raw.draw_indirect(buffer, offset, 1) }; + match family { + DrawCommandFamily::Draw => unsafe { raw.draw_indirect(buffer, offset, 1) }, + DrawCommandFamily::DrawIndexed => unsafe { + raw.draw_indexed_indirect(buffer, offset, 1) + }, + DrawCommandFamily::DrawMeshTasks => unsafe { + raw.draw_mesh_tasks_indirect(buffer, offset, 1); + }, } } Cmd::DrawIndirect { .. } | Cmd::MultiDrawIndirectCount { .. } => { @@ -1597,7 +1661,7 @@ where pub mod bundle_ffi { use super::{RenderBundleEncoder, RenderCommand}; - use crate::{id, RawString}; + use crate::{command::DrawCommandFamily, id, RawString}; use core::{convert::TryInto, slice}; use wgt::{BufferAddress, BufferSize, DynamicOffset, IndexFormat}; @@ -1752,7 +1816,7 @@ pub mod bundle_ffi { buffer_id, offset, count: 1, - indexed: false, + family: DrawCommandFamily::Draw, }); } @@ -1765,7 +1829,7 @@ pub mod bundle_ffi { buffer_id, offset, count: 1, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }); } diff --git a/wgpu-core/src/command/draw.rs b/wgpu-core/src/command/draw.rs index adb83643d9..7a57077a4f 100644 --- a/wgpu-core/src/command/draw.rs +++ b/wgpu-core/src/command/draw.rs @@ -56,6 +56,21 @@ pub enum DrawError { }, #[error(transparent)] BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch), + + #[error( + "Wrong pipeline type for this draw command. Attempted to call {} draw command on {} pipeline", + if *wanted_mesh_pipeline {"mesh shader"} else {"standard"}, + if *wanted_mesh_pipeline {"standard"} else {"mesh shader"}, + )] + WrongPipelineType { wanted_mesh_pipeline: bool }, + #[error( + "Each current draw group size dimension ({current:?}) must be less or equal to {limit}, and the product must be less or equal to {max_total}" + )] + InvalidGroupSize { + current: [u32; 3], + limit: u32, + max_total: u32, + }, } impl WebGpuError for DrawError { diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index 354a6ebac7..3abcf8307c 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -1442,6 +1442,15 @@ pub enum DrawKind { MultiDrawIndirectCount, } +/// The type of draw command(indexed or not, or mesh shader) +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DrawCommandFamily { + Draw, + DrawIndexed, + DrawMeshTasks, +} + /// A command that can be recorded in a pass or bundle. /// /// This is used to provide context for errors during command recording. @@ -1480,7 +1489,10 @@ pub enum PassErrorScope { #[error("In a set_scissor_rect command")] SetScissorRect, #[error("In a draw command, kind: {kind:?}")] - Draw { kind: DrawKind, indexed: bool }, + Draw { + kind: DrawKind, + family: DrawCommandFamily, + }, #[error("In a write_timestamp command")] WriteTimestamp, #[error("In a begin_occlusion_query command")] diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index aef81bc4a0..ba5fdbcd7d 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -54,7 +54,7 @@ use super::{ memory_init::TextureSurfaceDiscard, CommandBufferTextureMemoryActions, CommandEncoder, QueryResetMap, }; -use super::{DrawKind, Rect}; +use super::{DrawCommandFamily, DrawKind, Rect}; use crate::binding_model::{BindError, PushConstantUploadError}; pub use wgt::{LoadOp, StoreOp}; @@ -513,7 +513,7 @@ struct State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> { impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> { - fn is_ready(&self, indexed: bool) -> Result<(), DrawError> { + fn is_ready(&self, family: DrawCommandFamily) -> Result<(), DrawError> { if let Some(pipeline) = self.pipeline.as_ref() { self.general.binder.check_compatibility(pipeline.as_ref())?; self.general.binder.check_late_buffer_bindings()?; @@ -537,7 +537,7 @@ impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> }); } - if indexed { + if family == DrawCommandFamily::DrawIndexed { // Pipeline expects an index buffer if let Some(pipeline_index_format) = pipeline.strip_index_format { // We have a buffer bound @@ -556,6 +556,11 @@ impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> } } } + if (family == DrawCommandFamily::DrawMeshTasks) != pipeline.is_mesh { + return Err(DrawError::WrongPipelineType { + wanted_mesh_pipeline: !pipeline.is_mesh, + }); + } Ok(()) } else { Err(DrawError::MissingPipeline(pass::MissingPipeline)) @@ -2013,7 +2018,7 @@ impl Global { } => { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: false, + family: DrawCommandFamily::Draw, }; draw( &mut state, @@ -2033,7 +2038,7 @@ impl Global { } => { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; draw_indexed( &mut state, @@ -2045,11 +2050,28 @@ impl Global { ) .map_pass_err(scope)?; } + ArcRenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + } => { + let scope = PassErrorScope::Draw { + kind: DrawKind::Draw, + family: DrawCommandFamily::DrawMeshTasks, + }; + draw_mesh_tasks( + &mut state, + group_count_x, + group_count_y, + group_count_z, + ) + .map_pass_err(scope)?; + } ArcRenderCommand::DrawIndirect { buffer, offset, count, - indexed, + family, vertex_or_index_limit: _, instance_limit: _, @@ -2060,7 +2082,7 @@ impl Global { } else { DrawKind::DrawIndirect }, - indexed, + family, }; multi_draw_indirect( &mut state, @@ -2070,7 +2092,7 @@ impl Global { buffer, offset, count, - indexed, + family, ) .map_pass_err(scope)?; } @@ -2080,11 +2102,11 @@ impl Global { count_buffer, count_buffer_offset, max_count, - indexed, + family, } => { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirectCount, - indexed, + family, }; multi_draw_indirect_count( &mut state, @@ -2094,7 +2116,7 @@ impl Global { count_buffer, count_buffer_offset, max_count, - indexed, + family, ) .map_pass_err(scope)?; } @@ -2545,7 +2567,7 @@ fn draw( ) -> Result<(), DrawError> { api_log!("RenderPass::draw {vertex_count} {instance_count} {first_vertex} {first_instance}"); - state.is_ready(false)?; + state.is_ready(DrawCommandFamily::Draw)?; state .vertex @@ -2579,7 +2601,7 @@ fn draw_indexed( ) -> Result<(), DrawError> { api_log!("RenderPass::draw_indexed {index_count} {instance_count} {first_index} {base_vertex} {first_instance}"); - state.is_ready(true)?; + state.is_ready(DrawCommandFamily::DrawIndexed)?; let last_index = first_index as u64 + index_count as u64; let index_limit = state.index.limit; @@ -2608,6 +2630,45 @@ fn draw_indexed( Ok(()) } +fn draw_mesh_tasks( + state: &mut State, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, +) -> Result<(), DrawError> { + api_log!("RenderPass::draw_mesh_tasks {group_count_x} {group_count_y} {group_count_z}"); + + state.is_ready(DrawCommandFamily::DrawMeshTasks)?; + + let groups_size_limit = state + .general + .device + .limits + .max_task_workgroups_per_dimension; + let max_groups = state.general.device.limits.max_task_workgroup_total_count; + if group_count_x > groups_size_limit + || group_count_y > groups_size_limit + || group_count_z > groups_size_limit + || group_count_x * group_count_y * group_count_z > max_groups + { + return Err(DrawError::InvalidGroupSize { + current: [group_count_x, group_count_y, group_count_z], + limit: groups_size_limit, + max_total: max_groups, + }); + } + + unsafe { + if group_count_x > 0 && group_count_y > 0 && group_count_z > 0 { + state + .general + .raw_encoder + .draw_mesh_tasks(group_count_x, group_count_y, group_count_z); + } + } + Ok(()) +} + fn multi_draw_indirect( state: &mut State, indirect_draw_validation_resources: &mut crate::indirect_validation::DrawResources, @@ -2616,14 +2677,14 @@ fn multi_draw_indirect( indirect_buffer: Arc, offset: u64, count: u32, - indexed: bool, + family: DrawCommandFamily, ) -> Result<(), RenderPassErrorInner> { api_log!( - "RenderPass::draw_indirect (indexed:{indexed}) {} {offset} {count:?}", + "RenderPass::draw_indirect (family:{family:?}) {} {offset} {count:?}", indirect_buffer.error_ident() ); - state.is_ready(indexed)?; + state.is_ready(family)?; if count != 1 { state @@ -2645,7 +2706,7 @@ fn multi_draw_indirect( return Err(RenderPassErrorInner::UnalignedIndirectBufferOffset(offset)); } - let stride = get_stride_of_indirect_args(indexed); + let stride = get_stride_of_indirect_args(family); let end_offset = offset + stride * count as u64; if end_offset > indirect_buffer.size { @@ -2667,18 +2728,21 @@ fn multi_draw_indirect( fn draw( raw_encoder: &mut dyn hal::DynCommandEncoder, - indexed: bool, + family: DrawCommandFamily, indirect_buffer: &dyn hal::DynBuffer, offset: u64, count: u32, ) { - match indexed { - false => unsafe { + match family { + DrawCommandFamily::Draw => unsafe { raw_encoder.draw_indirect(indirect_buffer, offset, count); }, - true => unsafe { + DrawCommandFamily::DrawIndexed => unsafe { raw_encoder.draw_indexed_indirect(indirect_buffer, offset, count); }, + DrawCommandFamily::DrawMeshTasks => unsafe { + raw_encoder.draw_mesh_tasks_indirect(indirect_buffer, offset, count); + }, } } @@ -2703,7 +2767,7 @@ fn multi_draw_indirect( indirect_draw_validation_batcher: &'a mut crate::indirect_validation::DrawBatcher, indirect_buffer: Arc, - indexed: bool, + family: DrawCommandFamily, vertex_or_index_limit: u64, instance_limit: u64, } @@ -2715,7 +2779,7 @@ fn multi_draw_indirect( self.device, &self.indirect_buffer, offset, - self.indexed, + self.family, self.vertex_or_index_limit, self.instance_limit, )?; @@ -2731,7 +2795,7 @@ fn multi_draw_indirect( .get_dst_buffer(draw_data.buffer_index); draw( self.raw_encoder, - self.indexed, + self.family, dst_buffer, draw_data.offset, draw_data.count, @@ -2745,8 +2809,8 @@ fn multi_draw_indirect( indirect_draw_validation_resources, indirect_draw_validation_batcher, indirect_buffer, - indexed, - vertex_or_index_limit: if indexed { + family, + vertex_or_index_limit: if family == DrawCommandFamily::DrawIndexed { state.index.limit } else { state.vertex.limits.vertex_limit @@ -2781,7 +2845,7 @@ fn multi_draw_indirect( draw( state.general.raw_encoder, - indexed, + family, indirect_buffer.try_raw(state.general.snatch_guard)?, offset, count, @@ -2799,17 +2863,17 @@ fn multi_draw_indirect_count( count_buffer: Arc, count_buffer_offset: u64, max_count: u32, - indexed: bool, + family: DrawCommandFamily, ) -> Result<(), RenderPassErrorInner> { api_log!( - "RenderPass::multi_draw_indirect_count (indexed:{indexed}) {} {offset} {} {count_buffer_offset:?} {max_count:?}", + "RenderPass::multi_draw_indirect_count (family:{family:?}) {} {offset} {} {count_buffer_offset:?} {max_count:?}", indirect_buffer.error_ident(), count_buffer.error_ident() ); - state.is_ready(indexed)?; + state.is_ready(family)?; - let stride = get_stride_of_indirect_args(indexed); + let stride = get_stride_of_indirect_args(family); state .general @@ -2879,8 +2943,8 @@ fn multi_draw_indirect_count( ), ); - match indexed { - false => unsafe { + match family { + DrawCommandFamily::Draw => unsafe { state.general.raw_encoder.draw_indirect_count( indirect_raw, offset, @@ -2889,7 +2953,7 @@ fn multi_draw_indirect_count( max_count, ); }, - true => unsafe { + DrawCommandFamily::DrawIndexed => unsafe { state.general.raw_encoder.draw_indexed_indirect_count( indirect_raw, offset, @@ -2898,6 +2962,15 @@ fn multi_draw_indirect_count( max_count, ); }, + DrawCommandFamily::DrawMeshTasks => unsafe { + state.general.raw_encoder.draw_mesh_tasks_indirect_count( + indirect_raw, + offset, + count_raw, + count_buffer_offset, + max_count, + ); + }, } Ok(()) } @@ -3250,7 +3323,7 @@ impl Global { ) -> Result<(), PassStateError> { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: false, + family: DrawCommandFamily::Draw, }; let base = pass_base!(pass, scope); @@ -3275,7 +3348,7 @@ impl Global { ) -> Result<(), PassStateError> { let scope = PassErrorScope::Draw { kind: DrawKind::Draw, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; let base = pass_base!(pass, scope); @@ -3290,6 +3363,27 @@ impl Global { Ok(()) } + pub fn render_pass_draw_mesh_tasks( + &self, + pass: &mut RenderPass, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + ) -> Result<(), RenderPassError> { + let scope = PassErrorScope::Draw { + kind: DrawKind::Draw, + family: DrawCommandFamily::DrawMeshTasks, + }; + let base = pass_base!(pass, scope); + + base.commands.push(ArcRenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + }); + Ok(()) + } + pub fn render_pass_draw_indirect( &self, pass: &mut RenderPass, @@ -3298,7 +3392,7 @@ impl Global { ) -> Result<(), PassStateError> { let scope = PassErrorScope::Draw { kind: DrawKind::DrawIndirect, - indexed: false, + family: DrawCommandFamily::Draw, }; let base = pass_base!(pass, scope); @@ -3306,7 +3400,7 @@ impl Global { buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)), offset, count: 1, - indexed: false, + family: DrawCommandFamily::Draw, vertex_or_index_limit: 0, instance_limit: 0, @@ -3323,7 +3417,7 @@ impl Global { ) -> Result<(), PassStateError> { let scope = PassErrorScope::Draw { kind: DrawKind::DrawIndirect, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; let base = pass_base!(pass, scope); @@ -3331,7 +3425,32 @@ impl Global { buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)), offset, count: 1, - indexed: true, + family: DrawCommandFamily::DrawIndexed, + + vertex_or_index_limit: 0, + instance_limit: 0, + }); + + Ok(()) + } + + pub fn render_pass_draw_mesh_tasks_indirect( + &self, + pass: &mut RenderPass, + buffer_id: id::BufferId, + offset: BufferAddress, + ) -> Result<(), RenderPassError> { + let scope = PassErrorScope::Draw { + kind: DrawKind::DrawIndirect, + family: DrawCommandFamily::DrawMeshTasks, + }; + let base = pass_base!(pass, scope); + + base.commands.push(ArcRenderCommand::DrawIndirect { + buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)), + offset, + count: 1, + family: DrawCommandFamily::DrawMeshTasks, vertex_or_index_limit: 0, instance_limit: 0, @@ -3349,7 +3468,7 @@ impl Global { ) -> Result<(), PassStateError> { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirect, - indexed: false, + family: DrawCommandFamily::Draw, }; let base = pass_base!(pass, scope); @@ -3357,7 +3476,7 @@ impl Global { buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)), offset, count, - indexed: false, + family: DrawCommandFamily::Draw, vertex_or_index_limit: 0, instance_limit: 0, @@ -3375,7 +3494,7 @@ impl Global { ) -> Result<(), PassStateError> { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirect, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; let base = pass_base!(pass, scope); @@ -3383,7 +3502,33 @@ impl Global { buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)), offset, count, - indexed: true, + family: DrawCommandFamily::DrawIndexed, + + vertex_or_index_limit: 0, + instance_limit: 0, + }); + + Ok(()) + } + + pub fn render_pass_multi_draw_mesh_tasks_indirect( + &self, + pass: &mut RenderPass, + buffer_id: id::BufferId, + offset: BufferAddress, + count: u32, + ) -> Result<(), RenderPassError> { + let scope = PassErrorScope::Draw { + kind: DrawKind::MultiDrawIndirect, + family: DrawCommandFamily::DrawMeshTasks, + }; + let base = pass_base!(pass, scope); + + base.commands.push(ArcRenderCommand::DrawIndirect { + buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)), + offset, + count, + family: DrawCommandFamily::DrawMeshTasks, vertex_or_index_limit: 0, instance_limit: 0, @@ -3403,7 +3548,7 @@ impl Global { ) -> Result<(), PassStateError> { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirectCount, - indexed: false, + family: DrawCommandFamily::Draw, }; let base = pass_base!(pass, scope); @@ -3418,7 +3563,7 @@ impl Global { ), count_buffer_offset, max_count, - indexed: false, + family: DrawCommandFamily::Draw, }); Ok(()) @@ -3435,7 +3580,7 @@ impl Global { ) -> Result<(), PassStateError> { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirectCount, - indexed: true, + family: DrawCommandFamily::DrawIndexed, }; let base = pass_base!(pass, scope); @@ -3450,7 +3595,39 @@ impl Global { ), count_buffer_offset, max_count, - indexed: true, + family: DrawCommandFamily::DrawIndexed, + }); + + Ok(()) + } + + pub fn render_pass_multi_draw_mesh_tasks_indirect_count( + &self, + pass: &mut RenderPass, + buffer_id: id::BufferId, + offset: BufferAddress, + count_buffer_id: id::BufferId, + count_buffer_offset: BufferAddress, + max_count: u32, + ) -> Result<(), RenderPassError> { + let scope = PassErrorScope::Draw { + kind: DrawKind::MultiDrawIndirectCount, + family: DrawCommandFamily::DrawMeshTasks, + }; + let base = pass_base!(pass, scope); + + base.commands + .push(ArcRenderCommand::MultiDrawIndirectCount { + buffer: pass_try!(base, scope, self.resolve_render_pass_buffer_id(buffer_id)), + offset, + count_buffer: pass_try!( + base, + scope, + self.resolve_render_pass_buffer_id(count_buffer_id) + ), + count_buffer_offset, + max_count, + family: DrawCommandFamily::DrawMeshTasks, }); Ok(()) @@ -3607,9 +3784,10 @@ impl Global { } } -pub(crate) const fn get_stride_of_indirect_args(indexed: bool) -> u64 { - match indexed { - false => size_of::() as u64, - true => size_of::() as u64, +pub(crate) const fn get_stride_of_indirect_args(family: DrawCommandFamily) -> u64 { + match family { + DrawCommandFamily::Draw => size_of::() as u64, + DrawCommandFamily::DrawIndexed => size_of::() as u64, + DrawCommandFamily::DrawMeshTasks => size_of::() as u64, } } diff --git a/wgpu-core/src/command/render_command.rs b/wgpu-core/src/command/render_command.rs index 6564238548..f2972f9949 100644 --- a/wgpu-core/src/command/render_command.rs +++ b/wgpu-core/src/command/render_command.rs @@ -2,7 +2,7 @@ use alloc::sync::Arc; use wgt::{BufferAddress, BufferSize, Color}; -use super::{Rect, RenderBundle}; +use super::{DrawCommandFamily, Rect, RenderBundle}; use crate::{ binding_model::BindGroup, id, @@ -82,11 +82,16 @@ pub enum RenderCommand { base_vertex: i32, first_instance: u32, }, + DrawMeshTasks { + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + }, DrawIndirect { buffer_id: id::BufferId, offset: BufferAddress, count: u32, - indexed: bool, + family: DrawCommandFamily, }, MultiDrawIndirectCount { buffer_id: id::BufferId, @@ -94,7 +99,7 @@ pub enum RenderCommand { count_buffer_id: id::BufferId, count_buffer_offset: BufferAddress, max_count: u32, - indexed: bool, + family: DrawCommandFamily, }, PushDebugGroup { color: u32, @@ -310,12 +315,21 @@ impl RenderCommand { base_vertex, first_instance, }, + RenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + } => ArcRenderCommand::DrawMeshTasks { + group_count_x, + group_count_y, + group_count_z, + }, RenderCommand::DrawIndirect { buffer_id, offset, count, - indexed, + family, } => ArcRenderCommand::DrawIndirect { buffer: buffers_guard.get(buffer_id).get().map_err(|e| { RenderPassError { @@ -325,14 +339,14 @@ impl RenderCommand { } else { DrawKind::DrawIndirect }, - indexed, + family, }, inner: e.into(), } })?, offset, count, - indexed, + family, vertex_or_index_limit: 0, instance_limit: 0, @@ -344,11 +358,11 @@ impl RenderCommand { count_buffer_id, count_buffer_offset, max_count, - indexed, + family, } => { let scope = PassErrorScope::Draw { kind: DrawKind::MultiDrawIndirectCount, - indexed, + family, }; ArcRenderCommand::MultiDrawIndirectCount { buffer: buffers_guard.get(buffer_id).get().map_err(|e| { @@ -366,7 +380,7 @@ impl RenderCommand { )?, count_buffer_offset, max_count, - indexed, + family, } } @@ -473,11 +487,16 @@ pub enum ArcRenderCommand { base_vertex: i32, first_instance: u32, }, + DrawMeshTasks { + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + }, DrawIndirect { buffer: Arc, offset: BufferAddress, count: u32, - indexed: bool, + family: DrawCommandFamily, /// This limit is only populated for commands in a [`RenderBundle`]. vertex_or_index_limit: u64, @@ -490,7 +509,7 @@ pub enum ArcRenderCommand { count_buffer: Arc, count_buffer_offset: BufferAddress, max_count: u32, - indexed: bool, + family: DrawCommandFamily, }, PushDebugGroup { #[cfg_attr(not(any(feature = "serde", feature = "replay")), allow(dead_code))] diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index a76b24537c..92f10fc07f 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -16,8 +16,9 @@ use crate::{ id::{self, AdapterId, DeviceId, QueueId, SurfaceId}, instance::{self, Adapter, Surface}, pipeline::{ - self, ResolvedComputePipelineDescriptor, ResolvedFragmentState, - ResolvedProgrammableStageDescriptor, ResolvedRenderPipelineDescriptor, ResolvedVertexState, + self, RenderPipelineVertexProcessor, ResolvedComputePipelineDescriptor, + ResolvedFragmentState, ResolvedGeneralRenderPipelineDescriptor, ResolvedMeshState, + ResolvedProgrammableStageDescriptor, ResolvedTaskState, ResolvedVertexState, }, present, resource::{ @@ -1346,17 +1347,55 @@ impl Global { let fid = hub.render_pipelines.prepare(id_in); + let device = self.hub.devices.get(device_id); + #[cfg(feature = "trace")] + if let Some(ref mut trace) = *device.trace.lock() { + trace.add(trace::Action::CreateRenderPipeline { + id: fid.id(), + desc: desc.clone(), + }); + } + self.device_create_general_render_pipeline(desc.clone().into(), device, fid) + } + + pub fn device_create_mesh_pipeline( + &self, + device_id: DeviceId, + desc: &pipeline::MeshPipelineDescriptor, + id_in: Option, + ) -> ( + id::RenderPipelineId, + Option, + ) { + let hub = &self.hub; + + let fid = hub.render_pipelines.prepare(id_in); + + let device = self.hub.devices.get(device_id); + #[cfg(feature = "trace")] + if let Some(ref mut trace) = *device.trace.lock() { + trace.add(trace::Action::CreateMeshPipeline { + id: fid.id(), + desc: desc.clone(), + }); + } + self.device_create_general_render_pipeline(desc.clone().into(), device, fid) + } + + fn device_create_general_render_pipeline( + &self, + desc: pipeline::GeneralRenderPipelineDescriptor, + device: Arc, + fid: crate::registry::FutureId>, + ) -> ( + id::RenderPipelineId, + Option, + ) { + profiling::scope!("Device::create_general_render_pipeline"); + + let hub = &self.hub; + let error = 'error: { - let device = self.hub.devices.get(device_id); - - #[cfg(feature = "trace")] - if let Some(ref mut trace) = *device.trace.lock() { - trace.add(trace::Action::CreateRenderPipeline { - id: fid.id(), - desc: desc.clone(), - }); - } - if let Err(e) = device.check_is_valid() { break 'error e.into(); } @@ -1379,31 +1418,83 @@ impl Global { Err(e) => break 'error e.into(), }; - let vertex = { - let module = hub - .shader_modules - .get(desc.vertex.stage.module) - .get() - .map_err(|e| pipeline::CreateRenderPipelineError::Stage { - stage: wgt::ShaderStages::VERTEX, - error: e.into(), - }); - let module = match module { - Ok(module) => module, - Err(e) => break 'error e, - }; - let stage = ResolvedProgrammableStageDescriptor { - module, - entry_point: desc.vertex.stage.entry_point.clone(), - constants: desc.vertex.stage.constants.clone(), - zero_initialize_workgroup_memory: desc - .vertex - .stage - .zero_initialize_workgroup_memory, - }; - ResolvedVertexState { - stage, - buffers: desc.vertex.buffers.clone(), + let vertex = match desc.vertex { + RenderPipelineVertexProcessor::Vertex(ref vertex) => { + let module = hub + .shader_modules + .get(vertex.stage.module) + .get() + .map_err(|e| pipeline::CreateRenderPipelineError::Stage { + stage: wgt::ShaderStages::VERTEX, + error: e.into(), + }); + let module = match module { + Ok(module) => module, + Err(e) => break 'error e, + }; + let stage = ResolvedProgrammableStageDescriptor { + module, + entry_point: vertex.stage.entry_point.clone(), + constants: vertex.stage.constants.clone(), + zero_initialize_workgroup_memory: vertex + .stage + .zero_initialize_workgroup_memory, + }; + RenderPipelineVertexProcessor::Vertex(ResolvedVertexState { + stage, + buffers: vertex.buffers.clone(), + }) + } + RenderPipelineVertexProcessor::Mesh(ref task, ref mesh) => { + let task_module = if let Some(task) = task { + let module = hub + .shader_modules + .get(task.stage.module) + .get() + .map_err(|e| pipeline::CreateRenderPipelineError::Stage { + stage: wgt::ShaderStages::VERTEX, + error: e.into(), + }); + let module = match module { + Ok(module) => module, + Err(e) => break 'error e, + }; + let state = ResolvedProgrammableStageDescriptor { + module, + entry_point: task.stage.entry_point.clone(), + constants: task.stage.constants.clone(), + zero_initialize_workgroup_memory: task + .stage + .zero_initialize_workgroup_memory, + }; + Some(ResolvedTaskState { stage: state }) + } else { + None + }; + let mesh_module = + hub.shader_modules + .get(mesh.stage.module) + .get() + .map_err(|e| pipeline::CreateRenderPipelineError::Stage { + stage: wgt::ShaderStages::MESH, + error: e.into(), + }); + let mesh_module = match mesh_module { + Ok(module) => module, + Err(e) => break 'error e, + }; + let mesh_stage = ResolvedProgrammableStageDescriptor { + module: mesh_module, + entry_point: mesh.stage.entry_point.clone(), + constants: mesh.stage.constants.clone(), + zero_initialize_workgroup_memory: mesh + .stage + .zero_initialize_workgroup_memory, + }; + RenderPipelineVertexProcessor::Mesh( + task_module, + ResolvedMeshState { stage: mesh_stage }, + ) } }; @@ -1424,10 +1515,7 @@ impl Global { module, entry_point: state.stage.entry_point.clone(), constants: state.stage.constants.clone(), - zero_initialize_workgroup_memory: desc - .vertex - .stage - .zero_initialize_workgroup_memory, + zero_initialize_workgroup_memory: state.stage.zero_initialize_workgroup_memory, }; Some(ResolvedFragmentState { stage, @@ -1437,7 +1525,7 @@ impl Global { None }; - let desc = ResolvedRenderPipelineDescriptor { + let desc = ResolvedGeneralRenderPipelineDescriptor { label: desc.label.clone(), layout, vertex, diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 3156769440..ed03fce774 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -3472,7 +3472,7 @@ impl Device { pub(crate) fn create_render_pipeline( self: &Arc, - desc: pipeline::ResolvedRenderPipelineDescriptor, + desc: pipeline::ResolvedGeneralRenderPipelineDescriptor, ) -> Result, pipeline::CreateRenderPipelineError> { use wgt::TextureFormatFeatureFlags as Tfff; @@ -3513,127 +3513,137 @@ impl Device { let mut io = validation::StageIo::default(); let mut validated_stages = wgt::ShaderStages::empty(); - let mut vertex_steps = Vec::with_capacity(desc.vertex.buffers.len()); - let mut vertex_buffers = Vec::with_capacity(desc.vertex.buffers.len()); - let mut total_attributes = 0; + let mut vertex_steps; + let mut vertex_buffers; + let mut total_attributes; let mut shader_expects_dual_source_blending = false; let mut pipeline_expects_dual_source_blending = false; - for (i, vb_state) in desc.vertex.buffers.iter().enumerate() { - // https://gpuweb.github.io/gpuweb/#abstract-opdef-validating-gpuvertexbufferlayout + if let pipeline::RenderPipelineVertexProcessor::Vertex(ref vertex) = desc.vertex { + vertex_steps = Vec::with_capacity(vertex.buffers.len()); + vertex_buffers = Vec::with_capacity(vertex.buffers.len()); + total_attributes = 0; + shader_expects_dual_source_blending = false; + pipeline_expects_dual_source_blending = false; + for (i, vb_state) in vertex.buffers.iter().enumerate() { + // https://gpuweb.github.io/gpuweb/#abstract-opdef-validating-gpuvertexbufferlayout - if vb_state.array_stride > self.limits.max_vertex_buffer_array_stride as u64 { - return Err(pipeline::CreateRenderPipelineError::VertexStrideTooLarge { - index: i as u32, - given: vb_state.array_stride as u32, - limit: self.limits.max_vertex_buffer_array_stride, - }); - } - if vb_state.array_stride % wgt::VERTEX_ALIGNMENT != 0 { - return Err(pipeline::CreateRenderPipelineError::UnalignedVertexStride { - index: i as u32, + if vb_state.array_stride > self.limits.max_vertex_buffer_array_stride as u64 { + return Err(pipeline::CreateRenderPipelineError::VertexStrideTooLarge { + index: i as u32, + given: vb_state.array_stride as u32, + limit: self.limits.max_vertex_buffer_array_stride, + }); + } + if vb_state.array_stride % wgt::VERTEX_ALIGNMENT != 0 { + return Err(pipeline::CreateRenderPipelineError::UnalignedVertexStride { + index: i as u32, + stride: vb_state.array_stride, + }); + } + + let max_stride = if vb_state.array_stride == 0 { + self.limits.max_vertex_buffer_array_stride as u64 + } else { + vb_state.array_stride + }; + let mut last_stride = 0; + for attribute in vb_state.attributes.iter() { + let attribute_stride = attribute.offset + attribute.format.size(); + if attribute_stride > max_stride { + return Err( + pipeline::CreateRenderPipelineError::VertexAttributeStrideTooLarge { + location: attribute.shader_location, + given: attribute_stride as u32, + limit: max_stride as u32, + }, + ); + } + + let required_offset_alignment = attribute.format.size().min(4); + if attribute.offset % required_offset_alignment != 0 { + return Err( + pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset { + location: attribute.shader_location, + offset: attribute.offset, + }, + ); + } + + if attribute.shader_location >= self.limits.max_vertex_attributes { + return Err( + pipeline::CreateRenderPipelineError::TooManyVertexAttributes { + given: attribute.shader_location, + limit: self.limits.max_vertex_attributes, + }, + ); + } + + last_stride = last_stride.max(attribute_stride); + } + vertex_steps.push(pipeline::VertexStep { stride: vb_state.array_stride, + last_stride, + mode: vb_state.step_mode, + }); + if vb_state.attributes.is_empty() { + continue; + } + vertex_buffers.push(hal::VertexBufferLayout { + array_stride: vb_state.array_stride, + step_mode: vb_state.step_mode, + attributes: vb_state.attributes.as_ref(), + }); + + for attribute in vb_state.attributes.iter() { + if attribute.offset >= 0x10000000 { + return Err( + pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset { + location: attribute.shader_location, + offset: attribute.offset, + }, + ); + } + + if let wgt::VertexFormat::Float64 + | wgt::VertexFormat::Float64x2 + | wgt::VertexFormat::Float64x3 + | wgt::VertexFormat::Float64x4 = attribute.format + { + self.require_features(wgt::Features::VERTEX_ATTRIBUTE_64BIT)?; + } + + let previous = io.insert( + attribute.shader_location, + validation::InterfaceVar::vertex_attribute(attribute.format), + ); + + if previous.is_some() { + return Err(pipeline::CreateRenderPipelineError::ShaderLocationClash( + attribute.shader_location, + )); + } + } + total_attributes += vb_state.attributes.len(); + } + + if vertex_buffers.len() > self.limits.max_vertex_buffers as usize { + return Err(pipeline::CreateRenderPipelineError::TooManyVertexBuffers { + given: vertex_buffers.len() as u32, + limit: self.limits.max_vertex_buffers, }); } - - let max_stride = if vb_state.array_stride == 0 { - self.limits.max_vertex_buffer_array_stride as u64 - } else { - vb_state.array_stride - }; - let mut last_stride = 0; - for attribute in vb_state.attributes.iter() { - let attribute_stride = attribute.offset + attribute.format.size(); - if attribute_stride > max_stride { - return Err( - pipeline::CreateRenderPipelineError::VertexAttributeStrideTooLarge { - location: attribute.shader_location, - given: attribute_stride as u32, - limit: max_stride as u32, - }, - ); - } - - let required_offset_alignment = attribute.format.size().min(4); - if attribute.offset % required_offset_alignment != 0 { - return Err( - pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset { - location: attribute.shader_location, - offset: attribute.offset, - }, - ); - } - - if attribute.shader_location >= self.limits.max_vertex_attributes { - return Err( - pipeline::CreateRenderPipelineError::TooManyVertexAttributes { - given: attribute.shader_location, - limit: self.limits.max_vertex_attributes, - }, - ); - } - - last_stride = last_stride.max(attribute_stride); - } - vertex_steps.push(pipeline::VertexStep { - stride: vb_state.array_stride, - last_stride, - mode: vb_state.step_mode, - }); - if vb_state.attributes.is_empty() { - continue; - } - vertex_buffers.push(hal::VertexBufferLayout { - array_stride: vb_state.array_stride, - step_mode: vb_state.step_mode, - attributes: vb_state.attributes.as_ref(), - }); - - for attribute in vb_state.attributes.iter() { - if attribute.offset >= 0x10000000 { - return Err( - pipeline::CreateRenderPipelineError::InvalidVertexAttributeOffset { - location: attribute.shader_location, - offset: attribute.offset, - }, - ); - } - - if let wgt::VertexFormat::Float64 - | wgt::VertexFormat::Float64x2 - | wgt::VertexFormat::Float64x3 - | wgt::VertexFormat::Float64x4 = attribute.format - { - self.require_features(wgt::Features::VERTEX_ATTRIBUTE_64BIT)?; - } - - let previous = io.insert( - attribute.shader_location, - validation::InterfaceVar::vertex_attribute(attribute.format), + if total_attributes > self.limits.max_vertex_attributes as usize { + return Err( + pipeline::CreateRenderPipelineError::TooManyVertexAttributes { + given: total_attributes as u32, + limit: self.limits.max_vertex_attributes, + }, ); - - if previous.is_some() { - return Err(pipeline::CreateRenderPipelineError::ShaderLocationClash( - attribute.shader_location, - )); - } } - total_attributes += vb_state.attributes.len(); - } - - if vertex_buffers.len() > self.limits.max_vertex_buffers as usize { - return Err(pipeline::CreateRenderPipelineError::TooManyVertexBuffers { - given: vertex_buffers.len() as u32, - limit: self.limits.max_vertex_buffers, - }); - } - if total_attributes > self.limits.max_vertex_attributes as usize { - return Err( - pipeline::CreateRenderPipelineError::TooManyVertexAttributes { - given: total_attributes as u32, - limit: self.limits.max_vertex_attributes, - }, - ); - } + } else { + vertex_steps = Vec::new(); + vertex_buffers = Vec::new(); + }; if desc.primitive.strip_index_format.is_some() && !desc.primitive.topology.is_strip() { return Err( @@ -3843,44 +3853,132 @@ impl Device { sc }; - let vertex_entry_point_name; - let vertex_stage = { - let stage_desc = &desc.vertex.stage; - let stage = wgt::ShaderStages::VERTEX; + let mut vertex_stage = None; + let mut task_stage = None; + let mut mesh_stage = None; + let mut _vertex_entry_point_name = String::new(); + let mut _task_entry_point_name = String::new(); + let mut _mesh_entry_point_name = String::new(); + match desc.vertex { + pipeline::RenderPipelineVertexProcessor::Vertex(ref vertex) => { + vertex_stage = { + let stage_desc = &vertex.stage; + let stage = wgt::ShaderStages::VERTEX; - let vertex_shader_module = &stage_desc.module; - vertex_shader_module.same_device(self)?; + let vertex_shader_module = &stage_desc.module; + vertex_shader_module.same_device(self)?; - let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; + let stage_err = + |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; - vertex_entry_point_name = vertex_shader_module - .finalize_entry_point_name( - stage, - stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()), - ) - .map_err(stage_err)?; + _vertex_entry_point_name = vertex_shader_module + .finalize_entry_point_name( + stage, + stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(stage_err)?; - if let Some(ref interface) = vertex_shader_module.interface { - io = interface - .check_stage( - &mut binding_layout_source, - &mut shader_binding_sizes, - &vertex_entry_point_name, - stage, - io, - desc.depth_stencil.as_ref().map(|d| d.depth_compare), - ) - .map_err(stage_err)?; - validated_stages |= stage; + if let Some(ref interface) = vertex_shader_module.interface { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + &_vertex_entry_point_name, + stage, + io, + desc.depth_stencil.as_ref().map(|d| d.depth_compare), + ) + .map_err(stage_err)?; + validated_stages |= stage; + } + Some(hal::ProgrammableStage { + module: vertex_shader_module.raw(), + entry_point: &_vertex_entry_point_name, + constants: &stage_desc.constants, + zero_initialize_workgroup_memory: stage_desc + .zero_initialize_workgroup_memory, + }) + }; } + pipeline::RenderPipelineVertexProcessor::Mesh(ref task, ref mesh) => { + task_stage = if let Some(task) = task { + let stage_desc = &task.stage; + let stage = wgt::ShaderStages::TASK; + let task_shader_module = &stage_desc.module; + task_shader_module.same_device(self)?; - hal::ProgrammableStage { - module: vertex_shader_module.raw(), - entry_point: &vertex_entry_point_name, - constants: &stage_desc.constants, - zero_initialize_workgroup_memory: stage_desc.zero_initialize_workgroup_memory, + let stage_err = + |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; + + _task_entry_point_name = task_shader_module + .finalize_entry_point_name( + stage, + stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(stage_err)?; + + if let Some(ref interface) = task_shader_module.interface { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + &_task_entry_point_name, + stage, + io, + desc.depth_stencil.as_ref().map(|d| d.depth_compare), + ) + .map_err(stage_err)?; + validated_stages |= stage; + } + Some(hal::ProgrammableStage { + module: task_shader_module.raw(), + entry_point: &_task_entry_point_name, + constants: &stage_desc.constants, + zero_initialize_workgroup_memory: stage_desc + .zero_initialize_workgroup_memory, + }) + } else { + None + }; + mesh_stage = { + let stage_desc = &mesh.stage; + let stage = wgt::ShaderStages::MESH; + let mesh_shader_module = &stage_desc.module; + mesh_shader_module.same_device(self)?; + + let stage_err = + |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; + + _mesh_entry_point_name = mesh_shader_module + .finalize_entry_point_name( + stage, + stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(stage_err)?; + + if let Some(ref interface) = mesh_shader_module.interface { + io = interface + .check_stage( + &mut binding_layout_source, + &mut shader_binding_sizes, + &_mesh_entry_point_name, + stage, + io, + desc.depth_stencil.as_ref().map(|d| d.depth_compare), + ) + .map_err(stage_err)?; + validated_stages |= stage; + } + Some(hal::ProgrammableStage { + module: mesh_shader_module.raw(), + entry_point: &_mesh_entry_point_name, + constants: &stage_desc.constants, + zero_initialize_workgroup_memory: stage_desc + .zero_initialize_workgroup_memory, + }) + }; } - }; + } let fragment_entry_point_name; let fragment_stage = match desc.fragment { @@ -4029,20 +4127,29 @@ impl Device { None => None, }; - let pipeline_desc = hal::RenderPipelineDescriptor { - label: desc.label.to_hal(self.instance_flags), - layout: pipeline_layout.raw(), - vertex_buffers: &vertex_buffers, - vertex_stage, - primitive: desc.primitive, - depth_stencil: desc.depth_stencil.clone(), - multisample: desc.multisample, - fragment_stage, - color_targets, - multiview: desc.multiview, - cache: cache.as_ref().map(|it| it.raw()), - }; - let raw = + let is_mesh = mesh_stage.is_some(); + let raw = { + let pipeline_desc = hal::RenderPipelineDescriptor { + label: desc.label.to_hal(self.instance_flags), + layout: pipeline_layout.raw(), + vertex_processor: match vertex_stage { + Some(vertex_stage) => hal::VertexProcessor::Standard { + vertex_buffers: &vertex_buffers, + vertex_stage, + }, + None => hal::VertexProcessor::Mesh { + task_stage, + mesh_stage: mesh_stage.unwrap(), + }, + }, + primitive: desc.primitive, + depth_stencil: desc.depth_stencil.clone(), + multisample: desc.multisample, + fragment_stage, + color_targets, + multiview: desc.multiview, + cache: cache.as_ref().map(|it| it.raw()), + }; unsafe { self.raw().create_render_pipeline(&pipeline_desc) }.map_err( |err| match err { hal::PipelineError::Device(error) => { @@ -4061,7 +4168,8 @@ impl Device { pipeline::CreateRenderPipelineError::PipelineConstants { stage, error } } }, - )?; + )? + }; let pass_context = RenderPassContext { attachments: AttachmentData { @@ -4095,10 +4203,19 @@ impl Device { flags |= pipeline::PipelineFlags::WRITES_STENCIL; } } - let shader_modules = { let mut shader_modules = ArrayVec::new(); - shader_modules.push(desc.vertex.stage.module); + match desc.vertex { + pipeline::RenderPipelineVertexProcessor::Vertex(vertex) => { + shader_modules.push(vertex.stage.module) + } + pipeline::RenderPipelineVertexProcessor::Mesh(task, mesh) => { + if let Some(task) = task { + shader_modules.push(task.stage.module); + } + shader_modules.push(mesh.stage.module); + } + } shader_modules.extend(desc.fragment.map(|f| f.stage.module)); shader_modules }; @@ -4115,6 +4232,7 @@ impl Device { late_sized_buffer_groups, label: desc.label.to_string(), tracking_data: TrackingData::new(self.tracker_indices.render_pipelines.clone()), + is_mesh, }; let pipeline = Arc::new(pipeline); diff --git a/wgpu-core/src/device/trace.rs b/wgpu-core/src/device/trace.rs index 602264b5c3..58d26e4b07 100644 --- a/wgpu-core/src/device/trace.rs +++ b/wgpu-core/src/device/trace.rs @@ -103,6 +103,10 @@ pub enum Action<'a> { id: id::RenderPipelineId, desc: crate::pipeline::RenderPipelineDescriptor<'a>, }, + CreateMeshPipeline { + id: id::RenderPipelineId, + desc: crate::pipeline::MeshPipelineDescriptor<'a>, + }, DestroyRenderPipeline(id::RenderPipelineId), CreatePipelineCache { id: id::PipelineCacheId, diff --git a/wgpu-core/src/indirect_validation/draw.rs b/wgpu-core/src/indirect_validation/draw.rs index 886fc5bc9b..db23b0469d 100644 --- a/wgpu-core/src/indirect_validation/draw.rs +++ b/wgpu-core/src/indirect_validation/draw.rs @@ -919,7 +919,7 @@ impl DrawBatcher { device: &Device, src_buffer: &Arc, offset: u64, - indexed: bool, + family: crate::command::DrawCommandFamily, vertex_or_index_limit: u64, instance_limit: u64, ) -> Result<(usize, u64), DeviceError> { @@ -929,7 +929,7 @@ impl DrawBatcher { } else { 0 }; - let stride = extra + crate::command::get_stride_of_indirect_args(indexed); + let stride = extra + crate::command::get_stride_of_indirect_args(family); let (dst_resource_index, dst_offset) = indirect_draw_validation_resources .get_dst_subrange(stride, &mut self.current_dst_entry)?; @@ -941,7 +941,7 @@ impl DrawBatcher { let src_buffer_tracker_index = src_buffer.tracker_index(); let entry = MetadataEntry::new( - indexed, + family == crate::command::DrawCommandFamily::DrawIndexed, src_offset, dst_offset, vertex_or_index_limit, diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index 7d055b7052..f67cc65e2a 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -402,6 +402,33 @@ pub struct FragmentState<'a, SM = ShaderModuleId> { /// cbindgen:ignore pub type ResolvedFragmentState<'a> = FragmentState<'a, Arc>; +/// Describes the task shader in a mesh shader pipeline. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct TaskState<'a, SM = ShaderModuleId> { + /// The compiled task stage and its entry point. + pub stage: ProgrammableStageDescriptor<'a, SM>, +} + +pub type ResolvedTaskState<'a> = TaskState<'a, Arc>; + +/// Describes the mesh shader in a mesh shader pipeline. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct MeshState<'a, SM = ShaderModuleId> { + /// The compiled mesh stage and its entry point. + pub stage: ProgrammableStageDescriptor<'a, SM>, +} + +pub type ResolvedMeshState<'a> = MeshState<'a, Arc>; + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub(crate) enum RenderPipelineVertexProcessor<'a, SM = ShaderModuleId> { + Vertex(VertexState<'a, SM>), + Mesh(Option>, MeshState<'a, SM>), +} + /// Describes a render (graphics) pipeline. #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -433,10 +460,109 @@ pub struct RenderPipelineDescriptor< /// The pipeline cache to use when creating this pipeline. pub cache: Option, } +/// Describes a mesh shader pipeline. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct MeshPipelineDescriptor< + 'a, + PLL = PipelineLayoutId, + SM = ShaderModuleId, + PLC = PipelineCacheId, +> { + pub label: Label<'a>, + /// The layout of bind groups for this pipeline. + pub layout: Option, + /// The task processing state for this pipeline. + pub task: Option>, + /// The mesh processing state for this pipeline + pub mesh: MeshState<'a, SM>, + /// The properties of the pipeline at the primitive assembly and rasterization level. + #[cfg_attr(feature = "serde", serde(default))] + pub primitive: wgt::PrimitiveState, + /// The effect of draw calls on the depth and stencil aspects of the output target, if any. + #[cfg_attr(feature = "serde", serde(default))] + pub depth_stencil: Option, + /// The multi-sampling properties of the pipeline. + #[cfg_attr(feature = "serde", serde(default))] + pub multisample: wgt::MultisampleState, + /// The fragment processing state for this pipeline. + pub fragment: Option>, + /// If the pipeline will be used with a multiview render pass, this indicates how many array + /// layers the attachments will have. + pub multiview: Option, + /// The pipeline cache to use when creating this pipeline. + pub cache: Option, +} + +/// Describes a render (graphics) pipeline. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub(crate) struct GeneralRenderPipelineDescriptor< + 'a, + PLL = PipelineLayoutId, + SM = ShaderModuleId, + PLC = PipelineCacheId, +> { + pub label: Label<'a>, + /// The layout of bind groups for this pipeline. + pub layout: Option, + /// The vertex processing state for this pipeline. + pub vertex: RenderPipelineVertexProcessor<'a, SM>, + /// The properties of the pipeline at the primitive assembly and rasterization level. + #[cfg_attr(feature = "serde", serde(default))] + pub primitive: wgt::PrimitiveState, + /// The effect of draw calls on the depth and stencil aspects of the output target, if any. + #[cfg_attr(feature = "serde", serde(default))] + pub depth_stencil: Option, + /// The multi-sampling properties of the pipeline. + #[cfg_attr(feature = "serde", serde(default))] + pub multisample: wgt::MultisampleState, + /// The fragment processing state for this pipeline. + pub fragment: Option>, + /// If the pipeline will be used with a multiview render pass, this indicates how many array + /// layers the attachments will have. + pub multiview: Option, + /// The pipeline cache to use when creating this pipeline. + pub cache: Option, +} +impl<'a, PLL, SM, PLC> From> + for GeneralRenderPipelineDescriptor<'a, PLL, SM, PLC> +{ + fn from(value: RenderPipelineDescriptor<'a, PLL, SM, PLC>) -> Self { + Self { + label: value.label, + layout: value.layout, + vertex: RenderPipelineVertexProcessor::Vertex(value.vertex), + primitive: value.primitive, + depth_stencil: value.depth_stencil, + multisample: value.multisample, + fragment: value.fragment, + multiview: value.multiview, + cache: value.cache, + } + } +} +impl<'a, PLL, SM, PLC> From> + for GeneralRenderPipelineDescriptor<'a, PLL, SM, PLC> +{ + fn from(value: MeshPipelineDescriptor<'a, PLL, SM, PLC>) -> Self { + Self { + label: value.label, + layout: value.layout, + vertex: RenderPipelineVertexProcessor::Mesh(value.task, value.mesh), + primitive: value.primitive, + depth_stencil: value.depth_stencil, + multisample: value.multisample, + fragment: value.fragment, + multiview: value.multiview, + cache: value.cache, + } + } +} /// cbindgen:ignore -pub type ResolvedRenderPipelineDescriptor<'a> = - RenderPipelineDescriptor<'a, Arc, Arc, Arc>; +pub(crate) type ResolvedGeneralRenderPipelineDescriptor<'a> = + GeneralRenderPipelineDescriptor<'a, Arc, Arc, Arc>; #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -649,6 +775,8 @@ pub struct RenderPipeline { /// The `label` from the descriptor used to create the resource. pub(crate) label: String, pub(crate) tracking_data: TrackingData, + /// Whether this is a mesh shader pipeline + pub(crate) is_mesh: bool, } impl Drop for RenderPipeline { diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index 5bddd8a5b0..2127604064 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -1288,6 +1288,7 @@ impl Interface { ) } naga::ShaderStage::Compute => (false, 0), + // TODO: add validation for these, see https://github.com/gfx-rs/wgpu/issues/8003 naga::ShaderStage::Task | naga::ShaderStage::Mesh => { unreachable!() } diff --git a/wgpu-hal/examples/halmark/main.rs b/wgpu-hal/examples/halmark/main.rs index 36675e3b70..22f211c909 100644 --- a/wgpu-hal/examples/halmark/main.rs +++ b/wgpu-hal/examples/halmark/main.rs @@ -254,13 +254,15 @@ impl Example { let pipeline_desc = hal::RenderPipelineDescriptor { label: None, layout: &pipeline_layout, - vertex_stage: hal::ProgrammableStage { - module: &shader, - entry_point: "vs_main", - constants: &constants, - zero_initialize_workgroup_memory: true, + vertex_processor: hal::VertexProcessor::Standard { + vertex_stage: hal::ProgrammableStage { + module: &shader, + entry_point: "vs_main", + constants: &constants, + zero_initialize_workgroup_memory: true, + }, + vertex_buffers: &[], }, - vertex_buffers: &[], fragment_stage: Some(hal::ProgrammableStage { module: &shader, entry_point: "fs_main", diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 5c0bee0173..22d52be4d7 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -613,6 +613,12 @@ impl super::Adapter { // store buffer sizes using 32 bit ints (a situation we have already encountered with vulkan). max_buffer_size: i32::MAX as u64, max_non_sampler_bindings: 1_000_000, + + max_task_workgroup_total_count: 0, + max_task_workgroups_per_dimension: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, + max_blas_primitive_count: if supports_ray_tracing { 1 << 29 // 2^29 } else { diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 7d245e8ce3..5d3a6ac8d4 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -1747,8 +1747,16 @@ impl crate::Device for super::Device { let (topology_class, topology) = conv::map_topology(desc.primitive.topology); let mut shader_stages = wgt::ShaderStages::VERTEX; + let (vertex_stage_desc, vertex_buffers_desc) = match &desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + vertex_stage, + } => (vertex_stage, *vertex_buffers), + crate::VertexProcessor::Mesh { .. } => unreachable!(), + }; + let blob_vs = self.load_shader( - &desc.vertex_stage, + vertex_stage_desc, desc.layout, naga::ShaderStage::Vertex, desc.fragment_stage.as_ref(), @@ -1765,7 +1773,7 @@ impl crate::Device for super::Device { let mut input_element_descs = Vec::new(); for (i, (stride, vbuf)) in vertex_strides .iter_mut() - .zip(desc.vertex_buffers) + .zip(vertex_buffers_desc) .enumerate() { *stride = NonZeroU32::new(vbuf.array_stride as u32); @@ -1919,17 +1927,6 @@ impl crate::Device for super::Device { }) } - unsafe fn create_mesh_pipeline( - &self, - _desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - unreachable!() - } - unsafe fn destroy_render_pipeline(&self, _pipeline: super::RenderPipeline) { self.counters.render_pipelines.sub(1); } diff --git a/wgpu-hal/src/dynamic/device.rs b/wgpu-hal/src/dynamic/device.rs index de66b1619f..1f6ed91268 100644 --- a/wgpu-hal/src/dynamic/device.rs +++ b/wgpu-hal/src/dynamic/device.rs @@ -4,10 +4,10 @@ use crate::{ AccelerationStructureBuildSizes, AccelerationStructureDescriptor, Api, BindGroupDescriptor, BindGroupLayoutDescriptor, BufferDescriptor, BufferMapping, CommandEncoderDescriptor, ComputePipelineDescriptor, Device, DeviceError, FenceValue, - GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, MeshPipelineDescriptor, - PipelineCacheDescriptor, PipelineCacheError, PipelineError, PipelineLayoutDescriptor, - RenderPipelineDescriptor, SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor, - TextureDescriptor, TextureViewDescriptor, TlasInstance, + GetAccelerationStructureBuildSizesDescriptor, Label, MemoryRange, PipelineCacheDescriptor, + PipelineCacheError, PipelineError, PipelineLayoutDescriptor, RenderPipelineDescriptor, + SamplerDescriptor, ShaderError, ShaderInput, ShaderModuleDescriptor, TextureDescriptor, + TextureViewDescriptor, TlasInstance, }; use super::{ @@ -100,14 +100,6 @@ pub trait DynDevice: DynResource { dyn DynPipelineCache, >, ) -> Result, PipelineError>; - unsafe fn create_mesh_pipeline( - &self, - desc: &MeshPipelineDescriptor< - dyn DynPipelineLayout, - dyn DynShaderModule, - dyn DynPipelineCache, - >, - ) -> Result, PipelineError>; unsafe fn destroy_render_pipeline(&self, pipeline: Box); unsafe fn create_compute_pipeline( @@ -394,8 +386,22 @@ impl DynDevice for D { let desc = RenderPipelineDescriptor { label: desc.label, layout: desc.layout.expect_downcast_ref(), - vertex_buffers: desc.vertex_buffers, - vertex_stage: desc.vertex_stage.clone().expect_downcast(), + vertex_processor: match &desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + vertex_stage, + } => crate::VertexProcessor::Standard { + vertex_buffers, + vertex_stage: vertex_stage.clone().expect_downcast(), + }, + crate::VertexProcessor::Mesh { + task_stage: task, + mesh_stage: mesh, + } => crate::VertexProcessor::Mesh { + task_stage: task.as_ref().map(|a| a.clone().expect_downcast()), + mesh_stage: mesh.clone().expect_downcast(), + }, + }, primitive: desc.primitive, depth_stencil: desc.depth_stencil.clone(), multisample: desc.multisample, @@ -409,32 +415,6 @@ impl DynDevice for D { .map(|b| -> Box { Box::new(b) }) } - unsafe fn create_mesh_pipeline( - &self, - desc: &MeshPipelineDescriptor< - dyn DynPipelineLayout, - dyn DynShaderModule, - dyn DynPipelineCache, - >, - ) -> Result, PipelineError> { - let desc = MeshPipelineDescriptor { - label: desc.label, - layout: desc.layout.expect_downcast_ref(), - task_stage: desc.task_stage.clone().map(|f| f.expect_downcast()), - mesh_stage: desc.mesh_stage.clone().expect_downcast(), - primitive: desc.primitive, - depth_stencil: desc.depth_stencil.clone(), - multisample: desc.multisample, - fragment_stage: desc.fragment_stage.clone().map(|f| f.expect_downcast()), - color_targets: desc.color_targets, - multiview: desc.multiview, - cache: desc.cache.map(|c| c.expect_downcast_ref()), - }; - - unsafe { D::create_mesh_pipeline(self, &desc) } - .map(|b| -> Box { Box::new(b) }) - } - unsafe fn destroy_render_pipeline(&self, pipeline: Box) { unsafe { D::destroy_render_pipeline(self, pipeline.unbox()) }; } diff --git a/wgpu-hal/src/gles/adapter.rs b/wgpu-hal/src/gles/adapter.rs index d9d2812ccd..865f35013c 100644 --- a/wgpu-hal/src/gles/adapter.rs +++ b/wgpu-hal/src/gles/adapter.rs @@ -801,6 +801,12 @@ impl super::Adapter { max_compute_workgroups_per_dimension, max_buffer_size: i32::MAX as u64, max_non_sampler_bindings: u32::MAX, + + max_task_workgroup_total_count: 0, + max_task_workgroups_per_dimension: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, + max_blas_primitive_count: 0, max_blas_geometry_count: 0, max_tlas_instance_count: 0, diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index be9cc03fbe..0b1b77e11b 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -1363,9 +1363,16 @@ impl crate::Device for super::Device { super::PipelineCache, >, ) -> Result { + let (vertex_stage, vertex_buffers) = match &desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + ref vertex_stage, + } => (vertex_stage, vertex_buffers), + crate::VertexProcessor::Mesh { .. } => unreachable!(), + }; let gl = &self.shared.context.lock(); let mut shaders = ArrayVec::new(); - shaders.push((naga::ShaderStage::Vertex, &desc.vertex_stage)); + shaders.push((naga::ShaderStage::Vertex, vertex_stage)); if let Some(ref fs) = desc.fragment_stage { shaders.push((naga::ShaderStage::Fragment, fs)); } @@ -1375,7 +1382,7 @@ impl crate::Device for super::Device { let (vertex_buffers, vertex_attributes) = { let mut buffers = Vec::new(); let mut attributes = Vec::new(); - for (index, vb_layout) in desc.vertex_buffers.iter().enumerate() { + for (index, vb_layout) in vertex_buffers.iter().enumerate() { buffers.push(super::VertexBufferDesc { step: vb_layout.step_mode, stride: vb_layout.array_stride as u32, @@ -1430,16 +1437,6 @@ impl crate::Device for super::Device { alpha_to_coverage_enabled: desc.multisample.alpha_to_coverage_enabled, }) } - unsafe fn create_mesh_pipeline( - &self, - _desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - unreachable!() - } unsafe fn destroy_render_pipeline(&self, pipeline: super::RenderPipeline) { // If the pipeline only has 2 strong references remaining, they're `pipeline` and `program_cache` diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index 1e92bbece3..e6b4e0b0f8 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -931,15 +931,6 @@ pub trait Device: WasmNotSendSync { ::PipelineCache, >, ) -> Result<::RenderPipeline, PipelineError>; - #[allow(clippy::type_complexity)] - unsafe fn create_mesh_pipeline( - &self, - desc: &MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, PipelineError>; unsafe fn destroy_render_pipeline(&self, pipeline: ::RenderPipeline); #[allow(clippy::type_complexity)] @@ -2323,6 +2314,20 @@ pub struct VertexBufferLayout<'a> { pub attributes: &'a [wgt::VertexAttribute], } +#[derive(Clone, Debug)] +pub enum VertexProcessor<'a, M: DynShaderModule + ?Sized> { + Standard { + /// The format of any vertex buffers used with this pipeline. + vertex_buffers: &'a [VertexBufferLayout<'a>], + /// The vertex stage for this pipeline. + vertex_stage: ProgrammableStage<'a, M>, + }, + Mesh { + task_stage: Option>, + mesh_stage: ProgrammableStage<'a, M>, + }, +} + /// Describes a render (graphics) pipeline. #[derive(Clone, Debug)] pub struct RenderPipelineDescriptor< @@ -2334,37 +2339,8 @@ pub struct RenderPipelineDescriptor< pub label: Label<'a>, /// The layout of bind groups for this pipeline. pub layout: &'a Pl, - /// The format of any vertex buffers used with this pipeline. - pub vertex_buffers: &'a [VertexBufferLayout<'a>], - /// The vertex stage for this pipeline. - pub vertex_stage: ProgrammableStage<'a, M>, - /// The properties of the pipeline at the primitive assembly and rasterization level. - pub primitive: wgt::PrimitiveState, - /// The effect of draw calls on the depth and stencil aspects of the output target, if any. - pub depth_stencil: Option, - /// The multi-sampling properties of the pipeline. - pub multisample: wgt::MultisampleState, - /// The fragment stage for this pipeline. - pub fragment_stage: Option>, - /// The effect of draw calls on the color aspect of the output target. - pub color_targets: &'a [Option], - /// If the pipeline will be used with a multiview render pass, this indicates how many array - /// layers the attachments will have. - pub multiview: Option, - /// The cache which will be used and filled when compiling this pipeline - pub cache: Option<&'a Pc>, -} -pub struct MeshPipelineDescriptor< - 'a, - Pl: DynPipelineLayout + ?Sized, - M: DynShaderModule + ?Sized, - Pc: DynPipelineCache + ?Sized, -> { - pub label: Label<'a>, - /// The layout of bind groups for this pipeline. - pub layout: &'a Pl, - pub task_stage: Option>, - pub mesh_stage: ProgrammableStage<'a, M>, + /// The vertex processing state(vertex shader + buffers or task + mesh shaders) + pub vertex_processor: VertexProcessor<'a, M>, /// The properties of the pipeline at the primitive assembly and rasterization level. pub primitive: wgt::PrimitiveState, /// The effect of draw calls on the depth and stencil aspects of the output target, if any. diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index 6ecbff679f..7ab14ca76d 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -1077,6 +1077,12 @@ impl super::PrivateCapabilities { max_compute_workgroups_per_dimension: 0xFFFF, max_buffer_size: self.max_buffer_size, max_non_sampler_bindings: u32::MAX, + + max_task_workgroup_total_count: 0, + max_task_workgroups_per_dimension: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, + max_blas_primitive_count: 0, // When added: 2^28 from https://developer.apple.com/documentation/metal/mtlaccelerationstructureusage/extendedlimits max_blas_geometry_count: 0, // When added: 2^24 max_tlas_instance_count: 0, // When added: 2^24 diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index f4a79de640..fbb166c272 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1056,6 +1056,14 @@ impl crate::Device for super::Device { super::PipelineCache, >, ) -> Result { + let (desc_vertex_stage, desc_vertex_buffers) = match &desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + vertex_stage, + } => (vertex_stage, *vertex_buffers), + crate::VertexProcessor::Mesh { .. } => unreachable!(), + }; + objc::rc::autoreleasepool(|| { let descriptor = metal::RenderPipelineDescriptor::new(); @@ -1074,7 +1082,7 @@ impl crate::Device for super::Device { // Vertex shader let (vs_lib, vs_info) = { let mut vertex_buffer_mappings = Vec::::new(); - for (i, vbl) in desc.vertex_buffers.iter().enumerate() { + for (i, vbl) in desc_vertex_buffers.iter().enumerate() { let mut attributes = Vec::::new(); for attribute in vbl.attributes.iter() { attributes.push(naga::back::msl::AttributeMapping { @@ -1103,7 +1111,7 @@ impl crate::Device for super::Device { } let vs = self.load_shader( - &desc.vertex_stage, + desc_vertex_stage, &vertex_buffer_mappings, desc.layout, primitive_class, @@ -1216,12 +1224,12 @@ impl crate::Device for super::Device { None => None, }; - if desc.layout.total_counters.vs.buffers + (desc.vertex_buffers.len() as u32) + if desc.layout.total_counters.vs.buffers + (desc_vertex_buffers.len() as u32) > self.shared.private_caps.max_vertex_buffers { let msg = format!( "pipeline needs too many buffers in the vertex stage: {} vertex and {} layout", - desc.vertex_buffers.len(), + desc_vertex_buffers.len(), desc.layout.total_counters.vs.buffers ); return Err(crate::PipelineError::Linkage( @@ -1230,9 +1238,9 @@ impl crate::Device for super::Device { )); } - if !desc.vertex_buffers.is_empty() { + if !desc_vertex_buffers.is_empty() { let vertex_descriptor = metal::VertexDescriptor::new(); - for (i, vb) in desc.vertex_buffers.iter().enumerate() { + for (i, vb) in desc_vertex_buffers.iter().enumerate() { let buffer_index = self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64; let buffer_desc = vertex_descriptor.layouts().object_at(buffer_index).unwrap(); @@ -1318,17 +1326,6 @@ impl crate::Device for super::Device { }) } - unsafe fn create_mesh_pipeline( - &self, - _desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - unreachable!() - } - unsafe fn destroy_render_pipeline(&self, _pipeline: super::RenderPipeline) { self.counters.render_pipelines.sub(1); } diff --git a/wgpu-hal/src/noop/mod.rs b/wgpu-hal/src/noop/mod.rs index ef2d51294d..55965a7e2f 100644 --- a/wgpu-hal/src/noop/mod.rs +++ b/wgpu-hal/src/noop/mod.rs @@ -179,6 +179,12 @@ const CAPABILITIES: crate::Capabilities = { max_subgroup_size: ALLOC_MAX_U32, max_push_constant_size: ALLOC_MAX_U32, max_non_sampler_bindings: ALLOC_MAX_U32, + + max_task_workgroup_total_count: 0, + max_task_workgroups_per_dimension: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, + max_blas_primitive_count: ALLOC_MAX_U32, max_blas_geometry_count: ALLOC_MAX_U32, max_tlas_instance_count: ALLOC_MAX_U32, @@ -375,16 +381,6 @@ impl crate::Device for Context { ) -> Result { Ok(Resource) } - unsafe fn create_mesh_pipeline( - &self, - desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - Ok(Resource) - } unsafe fn destroy_render_pipeline(&self, pipeline: Resource) {} unsafe fn create_compute_pipeline( &self, diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index a06616c613..c2b05f91c7 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -932,7 +932,7 @@ pub struct PhysicalDeviceProperties { /// Additional `vk::PhysicalDevice` properties from the /// `VK_EXT_mesh_shader` extension. - _mesh_shader: Option>, + mesh_shader: Option>, /// The device API version. /// @@ -1160,6 +1160,20 @@ impl PhysicalDeviceProperties { let max_compute_workgroups_per_dimension = limits.max_compute_work_group_count[0] .min(limits.max_compute_work_group_count[1]) .min(limits.max_compute_work_group_count[2]); + let ( + max_task_workgroup_total_count, + max_task_workgroups_per_dimension, + max_mesh_multiview_count, + max_mesh_output_layers, + ) = match self.mesh_shader { + Some(m) => ( + m.max_task_work_group_total_count, + m.max_task_work_group_count.into_iter().min().unwrap(), + m.max_mesh_multiview_view_count, + m.max_mesh_output_layers, + ), + None => (0, 0, 0, 0), + }; // Prevent very large buffers on mesa and most android devices. let is_nvidia = self.properties.vendor_id == crate::auxil::db::nvidia::VENDOR; @@ -1267,6 +1281,12 @@ impl PhysicalDeviceProperties { max_compute_workgroups_per_dimension, max_buffer_size, max_non_sampler_bindings: u32::MAX, + + max_task_workgroup_total_count, + max_task_workgroups_per_dimension, + max_mesh_multiview_count, + max_mesh_output_layers, + max_blas_primitive_count, max_blas_geometry_count, max_tlas_instance_count, @@ -1401,7 +1421,7 @@ impl super::InstanceShared { if supports_mesh_shader { let next = capabilities - ._mesh_shader + .mesh_shader .insert(vk::PhysicalDeviceMeshShaderPropertiesEXT::default()); properties2 = properties2.push_next(next); } diff --git a/wgpu-hal/src/vulkan/conv.rs b/wgpu-hal/src/vulkan/conv.rs index 5177896ac5..5d067c0ef0 100644 --- a/wgpu-hal/src/vulkan/conv.rs +++ b/wgpu-hal/src/vulkan/conv.rs @@ -747,6 +747,12 @@ pub fn map_shader_stage(stage: wgt::ShaderStages) -> vk::ShaderStageFlags { if stage.contains(wgt::ShaderStages::COMPUTE) { flags |= vk::ShaderStageFlags::COMPUTE; } + if stage.contains(wgt::ShaderStages::TASK) { + flags |= vk::ShaderStageFlags::TASK_EXT; + } + if stage.contains(wgt::ShaderStages::MESH) { + flags |= vk::ShaderStageFlags::MESH_EXT; + } flags } diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 6b66bfa6aa..ecaf1f4062 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -1979,25 +1979,32 @@ impl crate::Device for super::Device { ..Default::default() }; let mut stages = ArrayVec::<_, { crate::MAX_CONCURRENT_SHADER_STAGES }>::new(); - let mut vertex_buffers = Vec::with_capacity(desc.vertex_buffers.len()); + let mut vertex_buffers = Vec::new(); let mut vertex_attributes = Vec::new(); - for (i, vb) in desc.vertex_buffers.iter().enumerate() { - vertex_buffers.push(vk::VertexInputBindingDescription { - binding: i as u32, - stride: vb.array_stride as u32, - input_rate: match vb.step_mode { - wgt::VertexStepMode::Vertex => vk::VertexInputRate::VERTEX, - wgt::VertexStepMode::Instance => vk::VertexInputRate::INSTANCE, - }, - }); - for at in vb.attributes { - vertex_attributes.push(vk::VertexInputAttributeDescription { - location: at.shader_location, + if let crate::VertexProcessor::Standard { + vertex_buffers: desc_vertex_buffers, + vertex_stage: _, + } = &desc.vertex_processor + { + vertex_buffers = Vec::with_capacity(desc_vertex_buffers.len()); + for (i, vb) in desc_vertex_buffers.iter().enumerate() { + vertex_buffers.push(vk::VertexInputBindingDescription { binding: i as u32, - format: conv::map_vertex_format(at.format), - offset: at.offset as u32, + stride: vb.array_stride as u32, + input_rate: match vb.step_mode { + wgt::VertexStepMode::Vertex => vk::VertexInputRate::VERTEX, + wgt::VertexStepMode::Instance => vk::VertexInputRate::INSTANCE, + }, }); + for at in vb.attributes { + vertex_attributes.push(vk::VertexInputAttributeDescription { + location: at.shader_location, + binding: i as u32, + format: conv::map_vertex_format(at.format), + offset: at.offset as u32, + }); + } } } @@ -2009,12 +2016,41 @@ impl crate::Device for super::Device { .topology(conv::map_topology(desc.primitive.topology)) .primitive_restart_enable(desc.primitive.strip_index_format.is_some()); - let compiled_vs = self.compile_stage( - &desc.vertex_stage, - naga::ShaderStage::Vertex, - &desc.layout.binding_arrays, - )?; - stages.push(compiled_vs.create_info); + let mut compiled_vs = None; + let mut compiled_ms = None; + let mut compiled_ts = None; + match &desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers: _, + vertex_stage, + } => { + compiled_vs = Some(self.compile_stage( + vertex_stage, + naga::ShaderStage::Vertex, + &desc.layout.binding_arrays, + )?); + stages.push(compiled_vs.as_ref().unwrap().create_info); + } + crate::VertexProcessor::Mesh { + task_stage, + mesh_stage, + } => { + if let Some(t) = task_stage.as_ref() { + compiled_ts = Some(self.compile_stage( + t, + naga::ShaderStage::Task, + &desc.layout.binding_arrays, + )?); + stages.push(compiled_ts.as_ref().unwrap().create_info); + } + compiled_ms = Some(self.compile_stage( + mesh_stage, + naga::ShaderStage::Mesh, + &desc.layout.binding_arrays, + )?); + stages.push(compiled_ms.as_ref().unwrap().create_info); + } + } let compiled_fs = match desc.fragment_stage { Some(ref stage) => { let compiled = self.compile_stage( @@ -2177,228 +2213,13 @@ impl crate::Device for super::Device { unsafe { self.shared.set_object_name(raw, label) }; } - if let Some(raw_module) = compiled_vs.temp_raw_module { - unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; - } if let Some(CompiledStage { temp_raw_module: Some(raw_module), .. - }) = compiled_fs + }) = compiled_vs { unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; } - - self.counters.render_pipelines.add(1); - - Ok(super::RenderPipeline { raw }) - } - unsafe fn create_mesh_pipeline( - &self, - desc: &crate::MeshPipelineDescriptor< - ::PipelineLayout, - ::ShaderModule, - ::PipelineCache, - >, - ) -> Result<::RenderPipeline, crate::PipelineError> { - let dynamic_states = [ - vk::DynamicState::VIEWPORT, - vk::DynamicState::SCISSOR, - vk::DynamicState::BLEND_CONSTANTS, - vk::DynamicState::STENCIL_REFERENCE, - ]; - let mut compatible_rp_key = super::RenderPassKey { - sample_count: desc.multisample.count, - multiview: desc.multiview, - ..Default::default() - }; - let mut stages = ArrayVec::<_, { crate::MAX_CONCURRENT_SHADER_STAGES }>::new(); - - let vk_input_assembly = vk::PipelineInputAssemblyStateCreateInfo::default() - .topology(conv::map_topology(desc.primitive.topology)) - .primitive_restart_enable(desc.primitive.strip_index_format.is_some()); - - let compiled_ts = match desc.task_stage { - Some(ref stage) => { - let mut compiled = self.compile_stage( - stage, - naga::ShaderStage::Task, - &desc.layout.binding_arrays, - )?; - compiled.create_info.stage = vk::ShaderStageFlags::TASK_EXT; - stages.push(compiled.create_info); - Some(compiled) - } - None => None, - }; - - let mut compiled_ms = self.compile_stage( - &desc.mesh_stage, - naga::ShaderStage::Mesh, - &desc.layout.binding_arrays, - )?; - compiled_ms.create_info.stage = vk::ShaderStageFlags::MESH_EXT; - stages.push(compiled_ms.create_info); - let compiled_fs = match desc.fragment_stage { - Some(ref stage) => { - let compiled = self.compile_stage( - stage, - naga::ShaderStage::Fragment, - &desc.layout.binding_arrays, - )?; - stages.push(compiled.create_info); - Some(compiled) - } - None => None, - }; - - let mut vk_rasterization = vk::PipelineRasterizationStateCreateInfo::default() - .polygon_mode(conv::map_polygon_mode(desc.primitive.polygon_mode)) - .front_face(conv::map_front_face(desc.primitive.front_face)) - .line_width(1.0) - .depth_clamp_enable(desc.primitive.unclipped_depth); - if let Some(face) = desc.primitive.cull_mode { - vk_rasterization = vk_rasterization.cull_mode(conv::map_cull_face(face)) - } - let mut vk_rasterization_conservative_state = - vk::PipelineRasterizationConservativeStateCreateInfoEXT::default() - .conservative_rasterization_mode( - vk::ConservativeRasterizationModeEXT::OVERESTIMATE, - ); - if desc.primitive.conservative { - vk_rasterization = vk_rasterization.push_next(&mut vk_rasterization_conservative_state); - } - - let mut vk_depth_stencil = vk::PipelineDepthStencilStateCreateInfo::default(); - if let Some(ref ds) = desc.depth_stencil { - let vk_format = self.shared.private_caps.map_texture_format(ds.format); - let vk_layout = if ds.is_read_only(desc.primitive.cull_mode) { - vk::ImageLayout::DEPTH_STENCIL_READ_ONLY_OPTIMAL - } else { - vk::ImageLayout::DEPTH_STENCIL_ATTACHMENT_OPTIMAL - }; - compatible_rp_key.depth_stencil = Some(super::DepthStencilAttachmentKey { - base: super::AttachmentKey::compatible(vk_format, vk_layout), - stencil_ops: crate::AttachmentOps::all(), - }); - - if ds.is_depth_enabled() { - vk_depth_stencil = vk_depth_stencil - .depth_test_enable(true) - .depth_write_enable(ds.depth_write_enabled) - .depth_compare_op(conv::map_comparison(ds.depth_compare)); - } - if ds.stencil.is_enabled() { - let s = &ds.stencil; - let front = conv::map_stencil_face(&s.front, s.read_mask, s.write_mask); - let back = conv::map_stencil_face(&s.back, s.read_mask, s.write_mask); - vk_depth_stencil = vk_depth_stencil - .stencil_test_enable(true) - .front(front) - .back(back); - } - - if ds.bias.is_enabled() { - vk_rasterization = vk_rasterization - .depth_bias_enable(true) - .depth_bias_constant_factor(ds.bias.constant as f32) - .depth_bias_clamp(ds.bias.clamp) - .depth_bias_slope_factor(ds.bias.slope_scale); - } - } - - let vk_viewport = vk::PipelineViewportStateCreateInfo::default() - .flags(vk::PipelineViewportStateCreateFlags::empty()) - .scissor_count(1) - .viewport_count(1); - - let vk_sample_mask = [ - desc.multisample.mask as u32, - (desc.multisample.mask >> 32) as u32, - ]; - let vk_multisample = vk::PipelineMultisampleStateCreateInfo::default() - .rasterization_samples(vk::SampleCountFlags::from_raw(desc.multisample.count)) - .alpha_to_coverage_enable(desc.multisample.alpha_to_coverage_enabled) - .sample_mask(&vk_sample_mask); - - let mut vk_attachments = Vec::with_capacity(desc.color_targets.len()); - for cat in desc.color_targets { - let (key, attarchment) = if let Some(cat) = cat.as_ref() { - let mut vk_attachment = vk::PipelineColorBlendAttachmentState::default() - .color_write_mask(vk::ColorComponentFlags::from_raw(cat.write_mask.bits())); - if let Some(ref blend) = cat.blend { - let (color_op, color_src, color_dst) = conv::map_blend_component(&blend.color); - let (alpha_op, alpha_src, alpha_dst) = conv::map_blend_component(&blend.alpha); - vk_attachment = vk_attachment - .blend_enable(true) - .color_blend_op(color_op) - .src_color_blend_factor(color_src) - .dst_color_blend_factor(color_dst) - .alpha_blend_op(alpha_op) - .src_alpha_blend_factor(alpha_src) - .dst_alpha_blend_factor(alpha_dst); - } - - let vk_format = self.shared.private_caps.map_texture_format(cat.format); - ( - Some(super::ColorAttachmentKey { - base: super::AttachmentKey::compatible( - vk_format, - vk::ImageLayout::COLOR_ATTACHMENT_OPTIMAL, - ), - resolve: None, - }), - vk_attachment, - ) - } else { - (None, vk::PipelineColorBlendAttachmentState::default()) - }; - - compatible_rp_key.colors.push(key); - vk_attachments.push(attarchment); - } - - let vk_color_blend = - vk::PipelineColorBlendStateCreateInfo::default().attachments(&vk_attachments); - - let vk_dynamic_state = - vk::PipelineDynamicStateCreateInfo::default().dynamic_states(&dynamic_states); - - let raw_pass = self.shared.make_render_pass(compatible_rp_key)?; - - let vk_infos = [{ - vk::GraphicsPipelineCreateInfo::default() - .layout(desc.layout.raw) - .stages(&stages) - .input_assembly_state(&vk_input_assembly) - .rasterization_state(&vk_rasterization) - .viewport_state(&vk_viewport) - .multisample_state(&vk_multisample) - .depth_stencil_state(&vk_depth_stencil) - .color_blend_state(&vk_color_blend) - .dynamic_state(&vk_dynamic_state) - .render_pass(raw_pass) - }]; - - let pipeline_cache = desc - .cache - .map(|it| it.raw) - .unwrap_or(vk::PipelineCache::null()); - - let mut raw_vec = { - profiling::scope!("vkCreateGraphicsPipelines"); - unsafe { - self.shared - .raw - .create_graphics_pipelines(pipeline_cache, &vk_infos, None) - .map_err(|(_, e)| super::map_pipeline_err(e)) - }? - }; - - let raw = raw_vec.pop().unwrap(); - if let Some(label) = desc.label { - unsafe { self.shared.set_object_name(raw, label) }; - } - // NOTE: this could leak shaders in case of an error. if let Some(CompiledStage { temp_raw_module: Some(raw_module), .. @@ -2406,7 +2227,11 @@ impl crate::Device for super::Device { { unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; } - if let Some(raw_module) = compiled_ms.temp_raw_module { + if let Some(CompiledStage { + temp_raw_module: Some(raw_module), + .. + }) = compiled_ms + { unsafe { self.shared.raw.destroy_shader_module(raw_module, None) }; } if let Some(CompiledStage { diff --git a/wgpu-info/src/human.rs b/wgpu-info/src/human.rs index d56cc324dc..f0930bd021 100644 --- a/wgpu-info/src/human.rs +++ b/wgpu-info/src/human.rs @@ -160,6 +160,12 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize max_subgroup_size, max_push_constant_size, max_non_sampler_bindings, + + max_task_workgroup_total_count, + max_task_workgroups_per_dimension, + max_mesh_multiview_count, + max_mesh_output_layers, + max_blas_primitive_count, max_blas_geometry_count, max_tlas_instance_count, @@ -200,6 +206,12 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize writeln!(output, "\t\t Max Compute Workgroup Size Y: {max_compute_workgroup_size_y}")?; writeln!(output, "\t\t Max Compute Workgroup Size Z: {max_compute_workgroup_size_z}")?; writeln!(output, "\t\t Max Compute Workgroups Per Dimension: {max_compute_workgroups_per_dimension}")?; + + writeln!(output, "\t\t Max Task Workgroup Total Count: {max_task_workgroup_total_count}")?; + writeln!(output, "\t\t Max Task Workgroups Per Dimension: {max_task_workgroups_per_dimension}")?; + writeln!(output, "\t\t Max Mesh Multiview Count: {max_mesh_multiview_count}")?; + writeln!(output, "\t\t Max Mesh Output Layers: {max_mesh_output_layers}")?; + writeln!(output, "\t\t Max BLAS Primitive count: {max_blas_primitive_count}")?; writeln!(output, "\t\t Max BLAS Geometry count: {max_blas_geometry_count}")?; writeln!(output, "\t\t Max TLAS Instance count: {max_tlas_instance_count}")?; diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 6b51d50464..573028759d 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -615,6 +615,17 @@ pub struct Limits { /// This limit only affects the d3d12 backend. Using a large number will allow the device /// to create many bind groups at the cost of a large up-front allocation at device creation. pub max_non_sampler_bindings: u32, + + /// The maximum total value of x*y*z for a given `draw_mesh_tasks` command + pub max_task_workgroup_total_count: u32, + /// The maximum value for each dimension of a `RenderPass::draw_mesh_tasks(x, y, z)` operation. + /// Defaults to 65535. Higher is "better". + pub max_task_workgroups_per_dimension: u32, + /// The maximum number of layers that can be output from a mesh shader + pub max_mesh_output_layers: u32, + /// The maximum number of views that can be used by a mesh shader + pub max_mesh_multiview_count: u32, + /// The maximum number of primitive (ex: triangles, aabbs) a BLAS is allowed to have. Requesting /// more than 0 during device creation only makes sense if [`Features::EXPERIMENTAL_RAY_QUERY`] /// is enabled. @@ -683,6 +694,10 @@ impl Limits { /// max_subgroup_size: 0, /// max_push_constant_size: 0, /// max_non_sampler_bindings: 1_000_000, + /// max_task_workgroup_total_count: 0, + /// max_task_workgroups_per_dimension: 0, + /// max_mesh_multiview_count: 0, + /// max_mesh_output_layers: 0, /// max_blas_primitive_count: 0, /// max_blas_geometry_count: 0, /// max_tlas_instance_count: 0, @@ -731,6 +746,12 @@ impl Limits { max_subgroup_size: 0, max_push_constant_size: 0, max_non_sampler_bindings: 1_000_000, + + max_task_workgroup_total_count: 0, + max_task_workgroups_per_dimension: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, + max_blas_primitive_count: 0, max_blas_geometry_count: 0, max_tlas_instance_count: 0, @@ -780,6 +801,12 @@ impl Limits { /// max_compute_workgroups_per_dimension: 65535, /// max_buffer_size: 256 << 20, // (256 MiB) /// max_non_sampler_bindings: 1_000_000, + /// + /// max_task_workgroup_total_count: 0, + /// max_task_workgroups_per_dimension: 0, + /// max_mesh_multiview_count: 0, + /// max_mesh_output_layers: 0, + /// /// max_blas_primitive_count: 0, /// max_blas_geometry_count: 0, /// max_tlas_instance_count: 0, @@ -797,6 +824,11 @@ impl Limits { max_color_attachments: 4, // see: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=7 max_compute_workgroup_storage_size: 16352, + + max_task_workgroups_per_dimension: 0, + max_task_workgroup_total_count: 0, + max_mesh_multiview_count: 0, + max_mesh_output_layers: 0, ..Self::defaults() } } @@ -844,6 +876,12 @@ impl Limits { /// max_compute_workgroups_per_dimension: 0, // + /// max_buffer_size: 256 << 20, // (256 MiB), /// max_non_sampler_bindings: 1_000_000, + /// + /// max_task_workgroup_total_count: 0, + /// max_task_workgroups_per_dimension: 0, + /// max_mesh_multiview_count: 0, + /// max_mesh_output_layers: 0, + /// /// max_blas_primitive_count: 0, /// max_blas_geometry_count: 0, /// max_tlas_instance_count: 0, @@ -929,6 +967,24 @@ impl Limits { } } + /// The recommended minimum limits for mesh shaders if you enable [`Features::EXPERIMENTAL_MESH_SHADER`] + /// + /// These are chosen somewhat arbitrarily. They are small enough that they should cover all physical devices, + /// but not necessarily all use cases. + #[must_use] + pub const fn using_recommended_minimum_mesh_shader_values(self) -> Self { + Self { + // Literally just made this up as 256^2 or 2^16. + // My GPU supports 2^22, and compute shaders don't have this kind of limit. + // This very likely is never a real limiter + max_task_workgroup_total_count: 65536, + max_task_workgroups_per_dimension: 256, + max_mesh_multiview_count: 1, + max_mesh_output_layers: 1024, + ..self + } + } + /// Compares every limits within self is within the limits given in `allowed`. /// /// If you need detailed information on failures, look at [`Limits::check_limits_with_fail_fn`]. @@ -1008,6 +1064,12 @@ impl Limits { } compare!(max_push_constant_size, Less); compare!(max_non_sampler_bindings, Less); + + compare!(max_task_workgroup_total_count, Less); + compare!(max_task_workgroups_per_dimension, Less); + compare!(max_mesh_multiview_count, Less); + compare!(max_mesh_output_layers, Less); + compare!(max_blas_primitive_count, Less); compare!(max_blas_geometry_count, Less); compare!(max_tlas_instance_count, Less); @@ -1402,9 +1464,9 @@ bitflags::bitflags! { const COMPUTE = 1 << 2; /// Binding is visible from the vertex and fragment shaders of a render pipeline. const VERTEX_FRAGMENT = Self::VERTEX.bits() | Self::FRAGMENT.bits(); - /// Binding is visible from the task shader of a mesh pipeline + /// Binding is visible from the task shader of a mesh pipeline. const TASK = 1 << 3; - /// Binding is visible from the mesh shader of a mesh pipeline + /// Binding is visible from the mesh shader of a mesh pipeline. const MESH = 1 << 4; } } diff --git a/wgpu/src/api/device.rs b/wgpu/src/api/device.rs index 2a192e2b3b..401431bcbd 100644 --- a/wgpu/src/api/device.rs +++ b/wgpu/src/api/device.rs @@ -245,6 +245,13 @@ impl Device { RenderPipeline { inner: pipeline } } + /// Creates a mesh shader based [`RenderPipeline`]. + #[must_use] + pub fn create_mesh_pipeline(&self, desc: &MeshPipelineDescriptor<'_>) -> RenderPipeline { + let pipeline = self.inner.create_mesh_pipeline(desc); + RenderPipeline { inner: pipeline } + } + /// Creates a [`ComputePipeline`]. #[must_use] pub fn create_compute_pipeline(&self, desc: &ComputePipelineDescriptor<'_>) -> ComputePipeline { diff --git a/wgpu/src/api/render_pass.rs b/wgpu/src/api/render_pass.rs index 0092c2af74..8163b4261f 100644 --- a/wgpu/src/api/render_pass.rs +++ b/wgpu/src/api/render_pass.rs @@ -226,6 +226,12 @@ impl RenderPass<'_> { self.inner.draw_indexed(indices, base_vertex, instances); } + /// Draws using a mesh shader pipeline + pub fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32) { + self.inner + .draw_mesh_tasks(group_count_x, group_count_y, group_count_z); + } + /// Draws primitives from the active vertex buffer(s) based on the contents of the `indirect_buffer`. /// /// This is like calling [`RenderPass::draw`] but the contents of the call are specified in the `indirect_buffer`. @@ -249,6 +255,25 @@ impl RenderPass<'_> { .draw_indexed_indirect(&indirect_buffer.inner, indirect_offset); } + /// Draws using a mesh shader pipeline, + /// based on the contents of the `indirect_buffer` + /// + /// This is like calling [`RenderPass::draw_mesh_tasks`] but the contents of the call are specified in the `indirect_buffer`. + /// The structure expected in the `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs). + /// + /// Indirect drawing has some caveats depending on the features available. We are not currently able to validate + /// these and issue an error. + /// + /// See details on the individual flags for more information. + pub fn draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &Buffer, + indirect_offset: BufferAddress, + ) { + self.inner + .draw_mesh_tasks_indirect(&indirect_buffer.inner, indirect_offset); + } + /// Execute a [render bundle][RenderBundle], which is a set of pre-recorded commands /// that can be run together. /// @@ -307,6 +332,23 @@ impl RenderPass<'_> { .multi_draw_indexed_indirect(&indirect_buffer.inner, indirect_offset, count); } + /// Dispatches multiple draw calls based on the contents of the `indirect_buffer`. + /// `count` draw calls are issued. + /// + /// The structure expected in the `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs). + /// + /// This drawing command uses the current render state, as set by preceding `set_*()` methods. + /// It is not affected by changes to the state that are performed after it is called. + pub fn multi_draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &Buffer, + indirect_offset: BufferAddress, + count: u32, + ) { + self.inner + .multi_draw_mesh_tasks_indirect(&indirect_buffer.inner, indirect_offset, count); + } + #[cfg(custom)] /// Returns custom implementation of RenderPass (if custom backend and is internally T) pub fn as_custom(&self) -> Option<&T> { @@ -395,6 +437,34 @@ impl RenderPass<'_> { max_count, ); } + + /// Dispatches multiple draw calls based on the contents of the `indirect_buffer`. The count buffer is read to determine how many draws to issue. + /// + /// The indirect buffer must be long enough to account for `max_count` draws, however only `count` + /// draws will be read. If `count` is greater than `max_count`, `max_count` will be used. + /// + /// The structure expected in the `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs). + /// + /// These draw structures are expected to be tightly packed. + /// + /// This drawing command uses the current render state, as set by preceding `set_*()` methods. + /// It is not affected by changes to the state that are performed after it is called. + pub fn multi_draw_mesh_tasks_indirect_count( + &mut self, + indirect_buffer: &Buffer, + indirect_offset: BufferAddress, + count_buffer: &Buffer, + count_offset: BufferAddress, + max_count: u32, + ) { + self.inner.multi_draw_mesh_tasks_indirect_count( + &indirect_buffer.inner, + indirect_offset, + &count_buffer.inner, + count_offset, + max_count, + ); + } } /// [`Features::PUSH_CONSTANTS`] must be enabled on the device in order to call these functions. diff --git a/wgpu/src/api/render_pipeline.rs b/wgpu/src/api/render_pipeline.rs index 27f96bce73..e887bb4b97 100644 --- a/wgpu/src/api/render_pipeline.rs +++ b/wgpu/src/api/render_pipeline.rs @@ -145,6 +145,48 @@ pub struct FragmentState<'a> { #[cfg(send_sync)] static_assertions::assert_impl_all!(FragmentState<'_>: Send, Sync); +/// Describes the task shader stage in a mesh shader pipeline. +/// +/// For use in [`MeshPipelineDescriptor`] +#[derive(Clone, Debug)] +pub struct TaskState<'a> { + /// The compiled shader module for this stage. + pub module: &'a ShaderModule, + /// The name of the entry point in the compiled shader to use. + /// + /// If [`Some`], there must be a vertex-stage shader entry point with this name in `module`. + /// Otherwise, expect exactly one vertex-stage entry point in `module`, which will be + /// selected. + pub entry_point: Option<&'a str>, + /// Advanced options for when this pipeline is compiled + /// + /// This implements `Default`, and for most users can be set to `Default::default()` + pub compilation_options: PipelineCompilationOptions<'a>, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(TaskState<'_>: Send, Sync); + +/// Describes the mesh shader stage in a mesh shader pipeline. +/// +/// For use in [`MeshPipelineDescriptor`] +#[derive(Clone, Debug)] +pub struct MeshState<'a> { + /// The compiled shader module for this stage. + pub module: &'a ShaderModule, + /// The name of the entry point in the compiled shader to use. + /// + /// If [`Some`], there must be a vertex-stage shader entry point with this name in `module`. + /// Otherwise, expect exactly one vertex-stage entry point in `module`, which will be + /// selected. + pub entry_point: Option<&'a str>, + /// Advanced options for when this pipeline is compiled + /// + /// This implements `Default`, and for most users can be set to `Default::default()` + pub compilation_options: PipelineCompilationOptions<'a>, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(MeshState<'_>: Send, Sync); + /// Describes a render (graphics) pipeline. /// /// For use with [`Device::create_render_pipeline`]. @@ -193,3 +235,51 @@ pub struct RenderPipelineDescriptor<'a> { } #[cfg(send_sync)] static_assertions::assert_impl_all!(RenderPipelineDescriptor<'_>: Send, Sync); + +/// Describes a mesh shader (graphics) pipeline. +/// +/// For use with [`Device::create_mesh_pipeline`]. +#[derive(Clone, Debug)] +pub struct MeshPipelineDescriptor<'a> { + /// Debug label of the pipeline. This will show up in graphics debuggers for easy identification. + pub label: Label<'a>, + /// The layout of bind groups for this pipeline. + /// + /// If this is set, then [`Device::create_render_pipeline`] will raise a validation error if + /// the layout doesn't match what the shader module(s) expect. + /// + /// Using the same [`PipelineLayout`] for many [`RenderPipeline`] or [`ComputePipeline`] + /// pipelines guarantees that you don't have to rebind any resources when switching between + /// those pipelines. + /// + /// ## Default pipeline layout + /// + /// If `layout` is `None`, then the pipeline has a [default layout] created and used instead. + /// The default layout is deduced from the shader modules. + /// + /// You can use [`RenderPipeline::get_bind_group_layout`] to create bind groups for use with the + /// default layout. However, these bind groups cannot be used with any other pipelines. This is + /// convenient for simple pipelines, but using an explicit layout is recommended in most cases. + /// + /// [default layout]: https://www.w3.org/TR/webgpu/#default-pipeline-layout + pub layout: Option<&'a PipelineLayout>, + /// The compiled task stage, its entry point, and the color targets. + pub task: Option>, + /// The compiled mesh stage and its entry point + pub mesh: MeshState<'a>, + /// The properties of the pipeline at the primitive assembly and rasterization level. + pub primitive: PrimitiveState, + /// The effect of draw calls on the depth and stencil aspects of the output target, if any. + pub depth_stencil: Option, + /// The multi-sampling properties of the pipeline. + pub multisample: MultisampleState, + /// The compiled fragment stage, its entry point, and the color targets. + pub fragment: Option>, + /// If the pipeline will be used with a multiview render pass, this indicates how many array + /// layers the attachments will have. + pub multiview: Option, + /// The pipeline cache to use when creating this pipeline. + pub cache: Option<&'a PipelineCache>, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(MeshPipelineDescriptor<'_>: Send, Sync); diff --git a/wgpu/src/backend/webgpu.rs b/wgpu/src/backend/webgpu.rs index 36ee0b5541..28fcc99b84 100644 --- a/wgpu/src/backend/webgpu.rs +++ b/wgpu/src/backend/webgpu.rs @@ -823,6 +823,12 @@ fn map_wgt_limits(limits: webgpu_sys::GpuSupportedLimits) -> wgt::Limits { max_push_constant_size: wgt::Limits::default().max_push_constant_size, max_non_sampler_bindings: wgt::Limits::default().max_non_sampler_bindings, max_inter_stage_shader_components: wgt::Limits::default().max_inter_stage_shader_components, + + max_task_workgroup_total_count: wgt::Limits::default().max_task_workgroup_total_count, + max_task_workgroups_per_dimension: wgt::Limits::default().max_task_workgroups_per_dimension, + max_mesh_output_layers: wgt::Limits::default().max_mesh_output_layers, + max_mesh_multiview_count: wgt::Limits::default().max_mesh_multiview_count, + max_blas_primitive_count: wgt::Limits::default().max_blas_primitive_count, max_blas_geometry_count: wgt::Limits::default().max_blas_geometry_count, max_tlas_instance_count: wgt::Limits::default().max_tlas_instance_count, @@ -2174,6 +2180,13 @@ impl dispatch::DeviceInterface for WebDevice { .into() } + fn create_mesh_pipeline( + &self, + _desc: &crate::MeshPipelineDescriptor<'_>, + ) -> dispatch::DispatchRenderPipeline { + panic!("MESH_SHADER feature must be enabled to call create_mesh_pipeline") + } + fn create_compute_pipeline( &self, desc: &crate::ComputePipelineDescriptor<'_>, @@ -3415,6 +3428,10 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder { ) } + fn draw_mesh_tasks(&mut self, _group_count_x: u32, _group_count_y: u32, _group_count_z: u32) { + panic!("MESH_SHADER feature must be enabled to call draw_mesh_tasks") + } + fn draw_indirect( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3435,6 +3452,14 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder { .draw_indexed_indirect_with_f64(&buffer.inner, indirect_offset as f64); } + fn draw_mesh_tasks_indirect( + &mut self, + _indirect_buffer: &dispatch::DispatchBuffer, + _indirect_offset: crate::BufferAddress, + ) { + panic!("MESH_SHADER feature must be enabled to call draw_mesh_tasks_indirect") + } + fn multi_draw_indirect( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3465,6 +3490,15 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder { } } + fn multi_draw_mesh_tasks_indirect( + &mut self, + _indirect_buffer: &dispatch::DispatchBuffer, + _indirect_offset: crate::BufferAddress, + _count: u32, + ) { + panic!("MESH_SHADER feature must be enabled to call multi_draw_mesh_tasks_indirect") + } + fn multi_draw_indirect_count( &mut self, _indirect_buffer: &dispatch::DispatchBuffer, @@ -3489,6 +3523,17 @@ impl dispatch::RenderPassInterface for WebRenderPassEncoder { panic!("MULTI_DRAW_INDIRECT_COUNT feature must be enabled to call multi_draw_indexed_indirect_count") } + fn multi_draw_mesh_tasks_indirect_count( + &mut self, + _indirect_buffer: &dispatch::DispatchBuffer, + _indirect_offset: crate::BufferAddress, + _count_buffer: &dispatch::DispatchBuffer, + _count_buffer_offset: crate::BufferAddress, + _max_count: u32, + ) { + panic!("MESH_SHADER feature must be enabled to call multi_draw_mesh_tasks_indirect_count") + } + fn insert_debug_marker(&mut self, _label: &str) { // Not available in gecko yet // self.inner.insert_debug_marker(label); diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 0a5648227c..b97f20870c 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -1354,6 +1354,102 @@ impl dispatch::DeviceInterface for CoreDevice { .into() } + fn create_mesh_pipeline( + &self, + desc: &crate::MeshPipelineDescriptor<'_>, + ) -> dispatch::DispatchRenderPipeline { + use wgc::pipeline as pipe; + + let mesh_constants = desc + .mesh + .compilation_options + .constants + .iter() + .map(|&(key, value)| (String::from(key), value)) + .collect(); + let descriptor = pipe::MeshPipelineDescriptor { + label: desc.label.map(Borrowed), + task: desc.task.as_ref().map(|task| { + let task_constants = task + .compilation_options + .constants + .iter() + .map(|&(key, value)| (String::from(key), value)) + .collect(); + pipe::TaskState { + stage: pipe::ProgrammableStageDescriptor { + module: task.module.inner.as_core().id, + entry_point: task.entry_point.map(Borrowed), + constants: task_constants, + zero_initialize_workgroup_memory: desc + .mesh + .compilation_options + .zero_initialize_workgroup_memory, + }, + } + }), + mesh: pipe::MeshState { + stage: pipe::ProgrammableStageDescriptor { + module: desc.mesh.module.inner.as_core().id, + entry_point: desc.mesh.entry_point.map(Borrowed), + constants: mesh_constants, + zero_initialize_workgroup_memory: desc + .mesh + .compilation_options + .zero_initialize_workgroup_memory, + }, + }, + layout: desc.layout.map(|layout| layout.inner.as_core().id), + primitive: desc.primitive, + depth_stencil: desc.depth_stencil.clone(), + multisample: desc.multisample, + fragment: desc.fragment.as_ref().map(|frag| { + let frag_constants = frag + .compilation_options + .constants + .iter() + .map(|&(key, value)| (String::from(key), value)) + .collect(); + pipe::FragmentState { + stage: pipe::ProgrammableStageDescriptor { + module: frag.module.inner.as_core().id, + entry_point: frag.entry_point.map(Borrowed), + constants: frag_constants, + zero_initialize_workgroup_memory: frag + .compilation_options + .zero_initialize_workgroup_memory, + }, + targets: Borrowed(frag.targets), + } + }), + multiview: desc.multiview, + cache: desc.cache.map(|cache| cache.inner.as_core().id), + }; + + let (id, error) = self + .context + .0 + .device_create_mesh_pipeline(self.id, &descriptor, None); + if let Some(cause) = error { + if let wgc::pipeline::CreateRenderPipelineError::Internal { stage, ref error } = cause { + log::error!("Shader translation error for stage {stage:?}: {error}"); + log::error!("Please report it to https://github.com/gfx-rs/wgpu"); + } + self.context.handle_error( + &self.error_sink, + cause, + desc.label, + "Device::create_render_pipeline", + ); + } + CoreRenderPipeline { + context: self.context.clone(), + id, + error_sink: Arc::clone(&self.error_sink), + } + .into() + } + fn create_compute_pipeline( &self, desc: &crate::ComputePipelineDescriptor<'_>, @@ -3125,6 +3221,22 @@ impl dispatch::RenderPassInterface for CoreRenderPass { } } + fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32) { + if let Err(cause) = self.context.0.render_pass_draw_mesh_tasks( + &mut self.pass, + group_count_x, + group_count_y, + group_count_z, + ) { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RenderPass::draw_mesh_tasks", + ); + } + } + fn draw_indirect( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3167,6 +3279,27 @@ impl dispatch::RenderPassInterface for CoreRenderPass { } } + fn draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &dispatch::DispatchBuffer, + indirect_offset: crate::BufferAddress, + ) { + let indirect_buffer = indirect_buffer.as_core(); + + if let Err(cause) = self.context.0.render_pass_draw_mesh_tasks_indirect( + &mut self.pass, + indirect_buffer.id, + indirect_offset, + ) { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RenderPass::draw_mesh_tasks_indirect", + ); + } + } + fn multi_draw_indirect( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3213,6 +3346,29 @@ impl dispatch::RenderPassInterface for CoreRenderPass { } } + fn multi_draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &dispatch::DispatchBuffer, + indirect_offset: crate::BufferAddress, + count: u32, + ) { + let indirect_buffer = indirect_buffer.as_core(); + + if let Err(cause) = self.context.0.render_pass_multi_draw_mesh_tasks_indirect( + &mut self.pass, + indirect_buffer.id, + indirect_offset, + count, + ) { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RenderPass::multi_draw_mesh_tasks_indirect", + ); + } + } + fn multi_draw_indirect_count( &mut self, indirect_buffer: &dispatch::DispatchBuffer, @@ -3273,6 +3429,38 @@ impl dispatch::RenderPassInterface for CoreRenderPass { } } + fn multi_draw_mesh_tasks_indirect_count( + &mut self, + indirect_buffer: &dispatch::DispatchBuffer, + indirect_offset: crate::BufferAddress, + count_buffer: &dispatch::DispatchBuffer, + count_buffer_offset: crate::BufferAddress, + max_count: u32, + ) { + let indirect_buffer = indirect_buffer.as_core(); + let count_buffer = count_buffer.as_core(); + + if let Err(cause) = self + .context + .0 + .render_pass_multi_draw_mesh_tasks_indirect_count( + &mut self.pass, + indirect_buffer.id, + indirect_offset, + count_buffer.id, + count_buffer_offset, + max_count, + ) + { + self.context.handle_error( + &self.error_sink, + cause, + self.pass.label(), + "RenderPass::multi_draw_mesh_tasks_indirect_count", + ); + } + } + fn insert_debug_marker(&mut self, label: &str) { if let Err(cause) = self .context diff --git a/wgpu/src/dispatch.rs b/wgpu/src/dispatch.rs index f7af8c3751..3b549a659e 100644 --- a/wgpu/src/dispatch.rs +++ b/wgpu/src/dispatch.rs @@ -147,6 +147,10 @@ pub trait DeviceInterface: CommonTraits { &self, desc: &crate::RenderPipelineDescriptor<'_>, ) -> DispatchRenderPipeline; + fn create_mesh_pipeline( + &self, + desc: &crate::MeshPipelineDescriptor<'_>, + ) -> DispatchRenderPipeline; fn create_compute_pipeline( &self, desc: &crate::ComputePipelineDescriptor<'_>, @@ -420,6 +424,7 @@ pub trait RenderPassInterface: CommonTraits { fn draw(&mut self, vertices: Range, instances: Range); fn draw_indexed(&mut self, indices: Range, base_vertex: i32, instances: Range); + fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32); fn draw_indirect( &mut self, indirect_buffer: &DispatchBuffer, @@ -430,6 +435,11 @@ pub trait RenderPassInterface: CommonTraits { indirect_buffer: &DispatchBuffer, indirect_offset: crate::BufferAddress, ); + fn draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &DispatchBuffer, + indirect_offset: crate::BufferAddress, + ); fn multi_draw_indirect( &mut self, @@ -451,6 +461,12 @@ pub trait RenderPassInterface: CommonTraits { count_buffer_offset: crate::BufferAddress, max_count: u32, ); + fn multi_draw_mesh_tasks_indirect( + &mut self, + indirect_buffer: &DispatchBuffer, + indirect_offset: crate::BufferAddress, + count: u32, + ); fn multi_draw_indexed_indirect_count( &mut self, indirect_buffer: &DispatchBuffer, @@ -459,6 +475,14 @@ pub trait RenderPassInterface: CommonTraits { count_buffer_offset: crate::BufferAddress, max_count: u32, ); + fn multi_draw_mesh_tasks_indirect_count( + &mut self, + indirect_buffer: &DispatchBuffer, + indirect_offset: crate::BufferAddress, + count_buffer: &DispatchBuffer, + count_buffer_offset: crate::BufferAddress, + max_count: u32, + ); fn insert_debug_marker(&mut self, label: &str); fn push_debug_group(&mut self, group_label: &str);