Files
wgpu/wgpu-core/src/command/compute_command.rs
Andreas Reich edf1a86148 Separate out ComputeCommand id->arc resolve (a step towards no lifetimes on wgpu::ComputePass) (#5432)
* move out compute command to separate module

* introduce ArcComputeCommand

* stateless tracker now returns reference to arc upon insertion

* add insert_merge_single to buffer tracker

* compute pass execution now works internally with an ArcComputeCommand

* compute pass execution now translates Command to ArcCommand ahead of time

* don't clone commands in compute pass execution

* remove doc hiding

* use option insert

* clippy fix

* fix private doc issue

* remove unnecessary copied over doc hide
2024-04-23 07:01:29 +00:00

323 lines
11 KiB
Rust

use std::sync::Arc;
use crate::{
binding_model::BindGroup,
hal_api::HalApi,
id,
pipeline::ComputePipeline,
resource::{Buffer, QuerySet},
};
use super::{ComputePassError, ComputePassErrorInner, PassErrorScope};
/// A compute pass command that refers to the resources it uses by id.
///
/// [`ComputeCommand::resolve_compute_command_ids`] turns a list of these into
/// [`ArcComputeCommand`]s, which carry the same data but hold `Arc`s to the
/// resources instead of ids.
///
/// The type is `Copy`, and (de)serializable when the `serde` feature is enabled.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ComputeCommand {
/// Use the bind group `bind_group_id` at `index`.
///
/// NOTE(review): `index` is presumably the bind group slot, and
/// `num_dynamic_offsets` the number of dynamic offsets stored out of line
/// for this command; confirm against the pass execution code.
SetBindGroup {
index: u32,
num_dynamic_offsets: usize,
bind_group_id: id::BindGroupId,
},
/// Use the compute pipeline with the given id.
SetPipeline(id::ComputePipelineId),
/// Set a range of push constants to values stored in `push_constant_data`.
SetPushConstant {
/// The byte offset within the push constant storage to write to. This
/// must be a multiple of four.
offset: u32,
/// The number of bytes to write. This must be a multiple of four.
size_bytes: u32,
/// Index in `push_constant_data` of the start of the data
/// to be written.
///
/// Note: this is not a byte offset like `offset`. Rather, it is the
/// index of the first `u32` element in `push_constant_data` to read.
values_offset: u32,
},
/// A direct dispatch.
///
/// NOTE(review): the three values are presumably the workgroup counts in
/// x, y and z; confirm.
Dispatch([u32; 3]),
/// A dispatch whose arguments come from the buffer `buffer_id`, starting at
/// `offset` (a `wgt::BufferAddress`).
DispatchIndirect {
buffer_id: id::BufferId,
offset: wgt::BufferAddress,
},
/// Open a debug group.
///
/// NOTE(review): `len` is presumably the byte length of a label stored
/// outside this command; confirm.
PushDebugGroup {
color: u32,
len: usize,
},
/// Close a debug group. Carries no data.
PopDebugGroup,
/// Insert a debug marker.
///
/// NOTE(review): `len` is presumably the byte length of a label stored
/// outside this command; confirm.
InsertDebugMarker {
color: u32,
len: usize,
},
/// Timestamp command addressing query `query_index` of the query set
/// `query_set_id`.
WriteTimestamp {
query_set_id: id::QuerySetId,
query_index: u32,
},
/// Start of a pipeline statistics query, addressing query `query_index` of
/// the query set `query_set_id`.
BeginPipelineStatisticsQuery {
query_set_id: id::QuerySetId,
query_index: u32,
},
/// End of a pipeline statistics query. Carries no data.
EndPipelineStatisticsQuery,
}
impl ComputeCommand {
/// Resolves all ids in a list of commands into the corresponding resource Arc.
///
// TODO: Once resolving is done on-the-fly during recording, this function should be only needed with the replay feature:
// #[cfg(feature = "replay")]
pub fn resolve_compute_command_ids<A: HalApi>(
hub: &crate::hub::Hub<A>,
commands: &[ComputeCommand],
) -> Result<Vec<ArcComputeCommand<A>>, ComputePassError> {
let buffers_guard = hub.buffers.read();
let bind_group_guard = hub.bind_groups.read();
let query_set_guard = hub.query_sets.read();
let pipelines_guard = hub.compute_pipelines.read();
let resolved_commands: Vec<ArcComputeCommand<A>> = commands
.iter()
.map(|c| -> Result<ArcComputeCommand<A>, ComputePassError> {
Ok(match *c {
ComputeCommand::SetBindGroup {
index,
num_dynamic_offsets,
bind_group_id,
} => ArcComputeCommand::SetBindGroup {
index,
num_dynamic_offsets,
bind_group: bind_group_guard.get_owned(bind_group_id).map_err(|_| {
ComputePassError {
scope: PassErrorScope::SetBindGroup(bind_group_id),
inner: ComputePassErrorInner::InvalidBindGroup(index),
}
})?,
},
ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
pipelines_guard
.get_owned(pipeline_id)
.map_err(|_| ComputePassError {
scope: PassErrorScope::SetPipelineCompute(pipeline_id),
inner: ComputePassErrorInner::InvalidPipeline(pipeline_id),
})?,
),
ComputeCommand::SetPushConstant {
offset,
size_bytes,
values_offset,
} => ArcComputeCommand::SetPushConstant {
offset,
size_bytes,
values_offset,
},
ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
ComputeCommand::DispatchIndirect { buffer_id, offset } => {
ArcComputeCommand::DispatchIndirect {
buffer: buffers_guard.get_owned(buffer_id).map_err(|_| {
ComputePassError {
scope: PassErrorScope::Dispatch {
indirect: true,
pipeline: None, // TODO: not used right now, but once we do the resolve during recording we can use this again.
},
inner: ComputePassErrorInner::InvalidBuffer(buffer_id),
}
})?,
offset,
}
}
ComputeCommand::PushDebugGroup { color, len } => {
ArcComputeCommand::PushDebugGroup { color, len }
}
ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
ComputeCommand::InsertDebugMarker { color, len } => {
ArcComputeCommand::InsertDebugMarker { color, len }
}
ComputeCommand::WriteTimestamp {
query_set_id,
query_index,
} => ArcComputeCommand::WriteTimestamp {
query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
ComputePassError {
scope: PassErrorScope::WriteTimestamp,
inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
}
})?,
query_index,
},
ComputeCommand::BeginPipelineStatisticsQuery {
query_set_id,
query_index,
} => ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
ComputePassError {
scope: PassErrorScope::BeginPipelineStatisticsQuery,
inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
}
})?,
query_index,
},
ComputeCommand::EndPipelineStatisticsQuery => {
ArcComputeCommand::EndPipelineStatisticsQuery
}
})
})
.collect::<Result<Vec<_>, ComputePassError>>()?;
Ok(resolved_commands)
}
}
/// Equivalent to [`ComputeCommand`], but with every resource id replaced by an
/// `Arc` of the resource it refers to.
///
/// The variants mirror those of [`ComputeCommand`] one to one; all fields that
/// are not ids are the same. Because it holds `Arc`s, this type is `Clone` but
/// not `Copy`.
#[derive(Clone, Debug)]
pub enum ArcComputeCommand<A: HalApi> {
/// Counterpart of [`ComputeCommand::SetBindGroup`], holding the bind group
/// itself.
SetBindGroup {
index: u32,
num_dynamic_offsets: usize,
bind_group: Arc<BindGroup<A>>,
},
/// Counterpart of [`ComputeCommand::SetPipeline`], holding the compute
/// pipeline itself.
SetPipeline(Arc<ComputePipeline<A>>),
/// Set a range of push constants to values stored in `push_constant_data`.
SetPushConstant {
/// The byte offset within the push constant storage to write to. This
/// must be a multiple of four.
offset: u32,
/// The number of bytes to write. This must be a multiple of four.
size_bytes: u32,
/// Index in `push_constant_data` of the start of the data
/// to be written.
///
/// Note: this is not a byte offset like `offset`. Rather, it is the
/// index of the first `u32` element in `push_constant_data` to read.
values_offset: u32,
},
/// Counterpart of [`ComputeCommand::Dispatch`]; carries the same three
/// values.
Dispatch([u32; 3]),
/// Counterpart of [`ComputeCommand::DispatchIndirect`], holding the buffer
/// itself together with the `offset` into it.
DispatchIndirect {
buffer: Arc<Buffer<A>>,
offset: wgt::BufferAddress,
},
/// Counterpart of [`ComputeCommand::PushDebugGroup`].
PushDebugGroup {
color: u32,
len: usize,
},
/// Counterpart of [`ComputeCommand::PopDebugGroup`]. Carries no data.
PopDebugGroup,
/// Counterpart of [`ComputeCommand::InsertDebugMarker`].
InsertDebugMarker {
color: u32,
len: usize,
},
/// Counterpart of [`ComputeCommand::WriteTimestamp`], holding the query
/// set itself.
WriteTimestamp {
query_set: Arc<QuerySet<A>>,
query_index: u32,
},
/// Counterpart of [`ComputeCommand::BeginPipelineStatisticsQuery`],
/// holding the query set itself.
BeginPipelineStatisticsQuery {
query_set: Arc<QuerySet<A>>,
query_index: u32,
},
/// Counterpart of [`ComputeCommand::EndPipelineStatisticsQuery`]. Carries
/// no data.
EndPipelineStatisticsQuery,
}
#[cfg(feature = "trace")]
impl<A: HalApi> From<&ArcComputeCommand<A>> for ComputeCommand {
    /// Builds the id-based form of an already resolved command.
    ///
    /// Every held resource is replaced by the id reported by
    /// `as_info().id()`; all remaining fields are copied unchanged. Only
    /// compiled when the `trace` feature is enabled.
    fn from(value: &ArcComputeCommand<A>) -> Self {
        // Anonymous trait import so the `as_info()` calls below resolve.
        use crate::resource::Resource as _;

        // Plain-data fields are `Copy` and bound by value; the `Arc` fields are
        // bound by reference so nothing is moved out of `value`.
        match *value {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                ref bind_group,
            } => Self::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group_id: bind_group.as_info().id(),
            },

            ArcComputeCommand::SetPipeline(ref pipeline) => {
                Self::SetPipeline(pipeline.as_info().id())
            }

            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => Self::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            },

            ArcComputeCommand::Dispatch(dim) => Self::Dispatch(dim),

            ArcComputeCommand::DispatchIndirect { ref buffer, offset } => Self::DispatchIndirect {
                buffer_id: buffer.as_info().id(),
                offset,
            },

            ArcComputeCommand::PushDebugGroup { color, len } => {
                Self::PushDebugGroup { color, len }
            }

            ArcComputeCommand::PopDebugGroup => Self::PopDebugGroup,

            ArcComputeCommand::InsertDebugMarker { color, len } => {
                Self::InsertDebugMarker { color, len }
            }

            ArcComputeCommand::WriteTimestamp {
                ref query_set,
                query_index,
            } => Self::WriteTimestamp {
                query_set_id: query_set.as_info().id(),
                query_index,
            },

            ArcComputeCommand::BeginPipelineStatisticsQuery {
                ref query_set,
                query_index,
            } => Self::BeginPipelineStatisticsQuery {
                query_set_id: query_set.as_info().id(),
                query_index,
            },

            ArcComputeCommand::EndPipelineStatisticsQuery => Self::EndPipelineStatisticsQuery,
        }
    }
}