diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 96a1d4488b..c80daf3236 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -604,68 +604,7 @@ impl Global { } ArcComputeCommand::SetPipeline(pipeline) => { let scope = PassErrorScope::SetPipelineCompute(pipeline.as_info().id()); - - pipeline.same_device_as(cmd_buf).map_pass_err(scope)?; - - state.pipeline = Some(pipeline.as_info().id()); - - let pipeline = state.tracker.compute_pipelines.insert_single(pipeline); - - unsafe { - raw.set_compute_pipeline(pipeline.raw()); - } - - // Rebind resources - if state.binder.pipeline_layout.is_none() - || !state - .binder - .pipeline_layout - .as_ref() - .unwrap() - .is_equal(&pipeline.layout) - { - let (start_index, entries) = state.binder.change_pipeline_layout( - &pipeline.layout, - &pipeline.late_sized_buffer_groups, - ); - if !entries.is_empty() { - for (i, e) in entries.iter().enumerate() { - if let Some(group) = e.group.as_ref() { - let raw_bg = - group.try_raw(&state.snatch_guard).map_pass_err(scope)?; - unsafe { - raw.set_bind_group( - pipeline.layout.raw(), - start_index as u32 + i as u32, - raw_bg, - &e.dynamic_offsets, - ); - } - } - } - } - - // Clear push constant ranges - let non_overlapping = super::bind::compute_nonoverlapping_ranges( - &pipeline.layout.push_constant_ranges, - ); - for range in non_overlapping { - let offset = range.range.start; - let size_bytes = range.range.end - offset; - super::push_constant_clear( - offset, - size_bytes, - |clear_offset, clear_data| unsafe { - raw.set_push_constants( - pipeline.layout.raw(), - wgt::ShaderStages::COMPUTE, - clear_offset, - clear_data, - ); - }, - ); - } - } + set_pipeline(&mut state, raw, cmd_buf, pipeline).map_pass_err(scope)?; } ArcComputeCommand::SetPushConstant { offset, @@ -987,6 +926,69 @@ fn set_bind_group( Ok(()) } +fn set_pipeline( + state: &mut State, + raw: &mut A::CommandEncoder, + cmd_buf: &CommandBuffer, + pipeline: Arc>, +) -> Result<(), ComputePassErrorInner> { + pipeline.same_device_as(cmd_buf)?; + + state.pipeline = Some(pipeline.as_info().id()); + + let pipeline = state.tracker.compute_pipelines.insert_single(pipeline); + + unsafe { + raw.set_compute_pipeline(pipeline.raw()); + } + + // Rebind resources + if state.binder.pipeline_layout.is_none() + || !state + .binder + .pipeline_layout + .as_ref() + .unwrap() + .is_equal(&pipeline.layout) + { + let (start_index, entries) = state + .binder + .change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups); + if !entries.is_empty() { + for (i, e) in entries.iter().enumerate() { + if let Some(group) = e.group.as_ref() { + let raw_bg = group.try_raw(&state.snatch_guard)?; + unsafe { + raw.set_bind_group( + pipeline.layout.raw(), + start_index as u32 + i as u32, + raw_bg, + &e.dynamic_offsets, + ); + } + } + } + } + + // Clear push constant ranges + let non_overlapping = + super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges); + for range in non_overlapping { + let offset = range.range.start; + let size_bytes = range.range.end - offset; + super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe { + raw.set_push_constants( + pipeline.layout.raw(), + wgt::ShaderStages::COMPUTE, + clear_offset, + clear_data, + ); + }); + } + } + Ok(()) +} + // Recording a compute pass. impl Global { pub fn compute_pass_set_bind_group(