diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 2068942475..bd5523efdc 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -53,11 +53,6 @@ pub struct ComputePass { // Resource binding dedupe state. current_bind_groups: BindGroupStateChange, current_pipeline: StateChange, - - /// The device that this pass is associated with. - /// - /// Used for quick validation during recording. - device_id: id::DeviceId, } impl ComputePass { @@ -68,10 +63,6 @@ impl ComputePass { timestamp_writes, } = desc; - let device_id = parent - .as_ref() - .map_or(id::DeviceId::dummy(0), |p| p.device.as_info().id()); - Self { base: Some(BasePass::new(label)), parent, @@ -79,8 +70,6 @@ impl ComputePass { current_bind_groups: BindGroupStateChange::new(), current_pipeline: StateChange::new(), - - device_id, } } @@ -593,6 +582,8 @@ impl Global { } => { let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id()); + bind_group.device.same_device(device).map_pass_err(scope)?; + let max_bind_groups = cmd_buf.limits.max_bind_groups; if index >= max_bind_groups { return Err(ComputePassErrorInner::BindGroupIndexOutOfRange { @@ -658,6 +649,8 @@ impl Global { let pipeline_id = pipeline.as_info().id(); let scope = PassErrorScope::SetPipelineCompute(pipeline_id); + pipeline.device.same_device(device).map_pass_err(scope)?; + state.pipeline = Some(pipeline_id); let pipeline = tracker.compute_pipelines.insert_single(pipeline); @@ -797,6 +790,8 @@ impl Global { pipeline: state.pipeline, }; + buffer.device.same_device(device).map_pass_err(scope)?; + state.is_ready().map_pass_err(scope)?; device @@ -890,6 +885,8 @@ impl Global { } => { let scope = PassErrorScope::WriteTimestamp; + query_set.device.same_device(device).map_pass_err(scope)?; + device .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES) .map_pass_err(scope)?; @@ -906,6 +903,8 @@ impl Global { } => { let scope = PassErrorScope::BeginPipelineStatisticsQuery; + query_set.device.same_device(device).map_pass_err(scope)?; + let query_set = tracker.query_sets.insert_single(query_set); validate_and_begin_pipeline_statistics_query( @@ -994,10 +993,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidBindGroup(index)) .map_pass_err(scope)?; - if bind_group.device.as_info().id() != pass.device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands.push(ArcComputeCommand::SetBindGroup { index, num_dynamic_offsets: offsets.len(), @@ -1016,7 +1011,6 @@ impl Global { let scope = PassErrorScope::SetPipelineCompute(pipeline_id); - let device_id = pass.device_id; let base = pass.base_mut(scope)?; if redundant { // Do redundant early-out **after** checking whether the pass is ended or not. @@ -1031,10 +1025,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id)) .map_pass_err(scope)?; - if pipeline.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands.push(ArcComputeCommand::SetPipeline(pipeline)); Ok(()) @@ -1108,7 +1098,6 @@ impl Global { indirect: true, pipeline: pass.current_pipeline.last_state, }; - let device_id = pass.device_id; let base = pass.base_mut(scope)?; let buffer = hub @@ -1118,10 +1107,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id)) .map_pass_err(scope)?; - if buffer.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands .push(ArcComputeCommand::::DispatchIndirect { buffer, offset }); @@ -1185,7 +1170,6 @@ impl Global { query_index: u32, ) -> Result<(), ComputePassError> { let scope = PassErrorScope::WriteTimestamp; - let device_id = pass.device_id; let base = pass.base_mut(scope)?; let hub = A::hub(self); @@ -1196,10 +1180,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id)) .map_pass_err(scope)?; - if query_set.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands.push(ArcComputeCommand::WriteTimestamp { query_set, query_index, @@ -1215,7 +1195,6 @@ impl Global { query_index: u32, ) -> Result<(), ComputePassError> { let scope = PassErrorScope::BeginPipelineStatisticsQuery; - let device_id = pass.device_id; let base = pass.base_mut(scope)?; let hub = A::hub(self); @@ -1226,10 +1205,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id)) .map_pass_err(scope)?; - if query_set.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands .push(ArcComputeCommand::BeginPipelineStatisticsQuery { query_set,