mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Stand-alone compute passes
This commit is contained in:
@@ -6,24 +6,48 @@ use crate::{
|
||||
command::{
|
||||
bind::{Binder, LayoutChange},
|
||||
CommandBuffer,
|
||||
OffsetIndex,
|
||||
},
|
||||
device::{all_buffer_stages, BIND_BUFFER_ALIGNMENT},
|
||||
hub::{GfxBackend, Global, IdentityFilter, Token},
|
||||
id::{BindGroupId, BufferId, CommandBufferId, ComputePassId, ComputePipelineId},
|
||||
id,
|
||||
resource::BufferUsage,
|
||||
track::TrackerSet,
|
||||
BufferAddress,
|
||||
Stored,
|
||||
};
|
||||
|
||||
use hal::{self, command::CommandBuffer as _};
|
||||
use hal::command::CommandBuffer as _;
|
||||
|
||||
use std::iter;
|
||||
use std::{iter, ops::Range};
|
||||
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug)]
|
||||
pub enum ComputeCommand {
|
||||
SetBindGroup {
|
||||
index: u32,
|
||||
bind_group_id: id::BindGroupId,
|
||||
offset_indices: Range<OffsetIndex>,
|
||||
},
|
||||
SetPipeline(id::ComputePipelineId),
|
||||
Dispatch([u32; 3]),
|
||||
DispatchIndirect {
|
||||
buffer_id: id::BufferId,
|
||||
offset: BufferAddress,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct StandaloneComputePass<'a> {
|
||||
pub commands: &'a [ComputeCommand],
|
||||
pub offsets: &'a [BufferAddress],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ComputePass<B: hal::Backend> {
|
||||
raw: B::CommandBuffer,
|
||||
cmb_id: Stored<CommandBufferId>,
|
||||
cmb_id: Stored<id::CommandBufferId>,
|
||||
binder: Binder,
|
||||
trackers: TrackerSet,
|
||||
}
|
||||
@@ -31,7 +55,7 @@ pub struct ComputePass<B: hal::Backend> {
|
||||
impl<B: hal::Backend> ComputePass<B> {
|
||||
pub(crate) fn new(
|
||||
raw: B::CommandBuffer,
|
||||
cmb_id: Stored<CommandBufferId>,
|
||||
cmb_id: Stored<id::CommandBufferId>,
|
||||
trackers: TrackerSet,
|
||||
max_bind_groups: u32,
|
||||
) -> Self {
|
||||
@@ -46,8 +70,8 @@ impl<B: hal::Backend> ComputePass<B> {
|
||||
|
||||
// Common routines between render/compute
|
||||
|
||||
impl<F: IdentityFilter<ComputePassId>> Global<F> {
|
||||
pub fn compute_pass_end_pass<B: GfxBackend>(&self, pass_id: ComputePassId) {
|
||||
impl<F: IdentityFilter<id::ComputePassId>> Global<F> {
|
||||
pub fn compute_pass_end_pass<B: GfxBackend>(&self, pass_id: id::ComputePassId) {
|
||||
let mut token = Token::root();
|
||||
let hub = B::hub(self);
|
||||
let (mut cmb_guard, mut token) = hub.command_buffers.write(&mut token);
|
||||
@@ -63,11 +87,154 @@ impl<F: IdentityFilter<ComputePassId>> Global<F> {
|
||||
}
|
||||
|
||||
impl<F> Global<F> {
|
||||
pub fn command_encoder_run_compute_pass<B: GfxBackend>(
|
||||
&self,
|
||||
encoder_id: id::CommandEncoderId,
|
||||
pass: StandaloneComputePass,
|
||||
) {
|
||||
let hub = B::hub(self);
|
||||
let mut token = Token::root();
|
||||
|
||||
let (mut cmb_guard, mut token) = hub.command_buffers.write(&mut token);
|
||||
let cmb = &mut cmb_guard[encoder_id];
|
||||
let raw = cmb.raw.last_mut().unwrap();
|
||||
let mut binder = Binder::new(cmb.features.max_bind_groups);
|
||||
|
||||
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
|
||||
let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
|
||||
let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
|
||||
let (buffer_guard, mut token) = hub.buffers.read(&mut token);
|
||||
let (texture_guard, _) = hub.textures.read(&mut token);
|
||||
|
||||
for command in pass.commands {
|
||||
match *command {
|
||||
ComputeCommand::SetBindGroup { index, bind_group_id, ref offset_indices } => {
|
||||
let offsets = &pass.offsets[offset_indices.start as usize .. offset_indices.end as usize];
|
||||
if cfg!(debug_assertions) {
|
||||
for off in offsets {
|
||||
assert_eq!(
|
||||
*off % BIND_BUFFER_ALIGNMENT,
|
||||
0,
|
||||
"Misaligned dynamic buffer offset: {} does not align with {}",
|
||||
off,
|
||||
BIND_BUFFER_ALIGNMENT
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let bind_group = cmb
|
||||
.trackers
|
||||
.bind_groups
|
||||
.use_extend(&*bind_group_guard, bind_group_id, (), ())
|
||||
.unwrap();
|
||||
assert_eq!(bind_group.dynamic_count, offsets.len());
|
||||
|
||||
log::trace!(
|
||||
"Encoding barriers on binding of {:?} to {:?}",
|
||||
bind_group_id,
|
||||
encoder_id
|
||||
);
|
||||
CommandBuffer::insert_barriers(
|
||||
raw,
|
||||
&mut cmb.trackers,
|
||||
&bind_group.used,
|
||||
&*buffer_guard,
|
||||
&*texture_guard,
|
||||
);
|
||||
|
||||
if let Some((pipeline_layout_id, follow_ups)) = binder
|
||||
.provide_entry(index as usize, bind_group_id, bind_group, offsets)
|
||||
{
|
||||
let bind_groups = iter::once(bind_group.raw.raw())
|
||||
.chain(follow_ups.clone().map(|(bg_id, _)| bind_group_guard[bg_id].raw.raw()));
|
||||
unsafe {
|
||||
raw.bind_compute_descriptor_sets(
|
||||
&pipeline_layout_guard[pipeline_layout_id].raw,
|
||||
index as usize,
|
||||
bind_groups,
|
||||
offsets
|
||||
.iter()
|
||||
.chain(follow_ups.flat_map(|(_, offsets)| offsets))
|
||||
.map(|&off| off as hal::command::DescriptorSetOffset),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
ComputeCommand::SetPipeline(pipeline_id) => {
|
||||
let pipeline = &pipeline_guard[pipeline_id];
|
||||
|
||||
unsafe {
|
||||
raw.bind_compute_pipeline(&pipeline.raw);
|
||||
}
|
||||
|
||||
// Rebind resources
|
||||
if binder.pipeline_layout_id != Some(pipeline.layout_id) {
|
||||
let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id];
|
||||
binder.pipeline_layout_id = Some(pipeline.layout_id);
|
||||
binder
|
||||
.reset_expectations(pipeline_layout.bind_group_layout_ids.len());
|
||||
let mut is_compatible = true;
|
||||
|
||||
for (index, (entry, &bgl_id)) in binder
|
||||
.entries
|
||||
.iter_mut()
|
||||
.zip(&pipeline_layout.bind_group_layout_ids)
|
||||
.enumerate()
|
||||
{
|
||||
match entry.expect_layout(bgl_id) {
|
||||
LayoutChange::Match(bg_id, offsets) if is_compatible => {
|
||||
let desc_set = bind_group_guard[bg_id].raw.raw();
|
||||
unsafe {
|
||||
raw.bind_compute_descriptor_sets(
|
||||
&pipeline_layout.raw,
|
||||
index,
|
||||
iter::once(desc_set),
|
||||
offsets.iter().map(|offset| *offset as u32),
|
||||
);
|
||||
}
|
||||
}
|
||||
LayoutChange::Match(..) | LayoutChange::Unchanged => {}
|
||||
LayoutChange::Mismatch => {
|
||||
is_compatible = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ComputeCommand::Dispatch(groups) => {
|
||||
unsafe {
|
||||
raw.dispatch(groups);
|
||||
}
|
||||
}
|
||||
ComputeCommand::DispatchIndirect { buffer_id, offset } => {
|
||||
let (src_buffer, src_pending) = cmb.trackers.buffers.use_replace(
|
||||
&*buffer_guard,
|
||||
buffer_id,
|
||||
(),
|
||||
BufferUsage::INDIRECT,
|
||||
);
|
||||
assert!(src_buffer.usage.contains(BufferUsage::INDIRECT));
|
||||
|
||||
let barriers = src_pending.map(|pending| pending.into_hal(src_buffer));
|
||||
|
||||
unsafe {
|
||||
raw.pipeline_barrier(
|
||||
all_buffer_stages() .. all_buffer_stages(),
|
||||
hal::memory::Dependencies::empty(),
|
||||
barriers,
|
||||
);
|
||||
raw.dispatch_indirect(&src_buffer.raw, offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compute_pass_set_bind_group<B: GfxBackend>(
|
||||
&self,
|
||||
pass_id: ComputePassId,
|
||||
pass_id: id::ComputePassId,
|
||||
index: u32,
|
||||
bind_group_id: BindGroupId,
|
||||
bind_group_id: id::BindGroupId,
|
||||
offsets: &[BufferAddress],
|
||||
) {
|
||||
let hub = B::hub(self);
|
||||
@@ -140,7 +307,7 @@ impl<F> Global<F> {
|
||||
|
||||
pub fn compute_pass_dispatch<B: GfxBackend>(
|
||||
&self,
|
||||
pass_id: ComputePassId,
|
||||
pass_id: id::ComputePassId,
|
||||
x: u32,
|
||||
y: u32,
|
||||
z: u32,
|
||||
@@ -155,8 +322,8 @@ impl<F> Global<F> {
|
||||
|
||||
pub fn compute_pass_dispatch_indirect<B: GfxBackend>(
|
||||
&self,
|
||||
pass_id: ComputePassId,
|
||||
indirect_buffer_id: BufferId,
|
||||
pass_id: id::ComputePassId,
|
||||
indirect_buffer_id: id::BufferId,
|
||||
indirect_offset: BufferAddress,
|
||||
) {
|
||||
let hub = B::hub(self);
|
||||
@@ -188,8 +355,8 @@ impl<F> Global<F> {
|
||||
|
||||
pub fn compute_pass_set_pipeline<B: GfxBackend>(
|
||||
&self,
|
||||
pass_id: ComputePassId,
|
||||
pipeline_id: ComputePipelineId,
|
||||
pass_id: id::ComputePassId,
|
||||
pipeline_id: id::ComputePipelineId,
|
||||
) {
|
||||
let hub = B::hub(self);
|
||||
let mut token = Token::root();
|
||||
|
||||
@@ -55,6 +55,8 @@ use std::{
|
||||
};
|
||||
|
||||
|
||||
pub type OffsetIndex = u16;
|
||||
|
||||
pub struct RenderBundle<B: hal::Backend> {
|
||||
_raw: B::CommandBuffer,
|
||||
}
|
||||
|
||||
@@ -170,6 +170,7 @@ impl<B: hal::Backend> Access<Device<B>> for Adapter<B> {}
|
||||
impl<B: hal::Backend> Access<SwapChain<B>> for Device<B> {}
|
||||
impl<B: hal::Backend> Access<PipelineLayout<B>> for Root {}
|
||||
impl<B: hal::Backend> Access<PipelineLayout<B>> for Device<B> {}
|
||||
impl<B: hal::Backend> Access<PipelineLayout<B>> for CommandBuffer<B> {}
|
||||
impl<B: hal::Backend> Access<BindGroupLayout<B>> for Root {}
|
||||
impl<B: hal::Backend> Access<BindGroupLayout<B>> for Device<B> {}
|
||||
impl<B: hal::Backend> Access<BindGroup<B>> for Root {}
|
||||
@@ -187,8 +188,10 @@ impl<B: hal::Backend> Access<RenderPass<B>> for Root {}
|
||||
impl<B: hal::Backend> Access<RenderPass<B>> for BindGroup<B> {}
|
||||
impl<B: hal::Backend> Access<RenderPass<B>> for CommandBuffer<B> {}
|
||||
impl<B: hal::Backend> Access<ComputePipeline<B>> for Root {}
|
||||
impl<B: hal::Backend> Access<ComputePipeline<B>> for BindGroup<B> {}
|
||||
impl<B: hal::Backend> Access<ComputePipeline<B>> for ComputePass<B> {}
|
||||
impl<B: hal::Backend> Access<RenderPipeline<B>> for Root {}
|
||||
impl<B: hal::Backend> Access<RenderPipeline<B>> for BindGroup<B> {}
|
||||
impl<B: hal::Backend> Access<RenderPipeline<B>> for RenderPass<B> {}
|
||||
impl<B: hal::Backend> Access<ShaderModule<B>> for Root {}
|
||||
impl<B: hal::Backend> Access<ShaderModule<B>> for PipelineLayout<B> {}
|
||||
|
||||
@@ -20,6 +20,7 @@ struct IdentityHub {
|
||||
adapters: IdentityManager,
|
||||
devices: IdentityManager,
|
||||
buffers: IdentityManager,
|
||||
command_buffers: IdentityManager,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -153,3 +154,30 @@ pub extern "C" fn wgpu_client_kill_buffer_id(client: &Client, id: id::BufferId)
|
||||
.buffers
|
||||
.free(id)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn wgpu_client_make_encoder_id(
|
||||
client: &Client,
|
||||
device_id: id::DeviceId,
|
||||
) -> id::CommandEncoderId {
|
||||
let backend = device_id.backend();
|
||||
client
|
||||
.identities
|
||||
.lock()
|
||||
.select(backend)
|
||||
.command_buffers
|
||||
.alloc(backend)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn wgpu_client_kill_encoder_id(
|
||||
client: &Client,
|
||||
id: id::CommandEncoderId,
|
||||
) {
|
||||
client
|
||||
.identities
|
||||
.lock()
|
||||
.select(id.backend())
|
||||
.command_buffers
|
||||
.free(id)
|
||||
}
|
||||
|
||||
@@ -101,3 +101,51 @@ pub extern "C" fn wgpu_server_device_get_buffer_sub_data(
|
||||
pub extern "C" fn wgpu_server_buffer_destroy(global: &Global, self_id: id::BufferId) {
|
||||
gfx_select!(self_id => global.buffer_destroy(self_id));
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn wgpu_server_device_create_encoder(
|
||||
global: &Global,
|
||||
self_id: id::DeviceId,
|
||||
encoder_id: id::CommandEncoderId,
|
||||
) {
|
||||
let desc = core::command::CommandEncoderDescriptor {
|
||||
todo: 0,
|
||||
};
|
||||
gfx_select!(self_id => global.device_create_command_encoder(self_id, &desc, encoder_id));
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn wgpu_server_encoder_destroy(
|
||||
_global: &Global,
|
||||
_self_id: id::CommandEncoderId,
|
||||
) {
|
||||
//TODO
|
||||
//gfx_select!(self_id => global.command_encoder_destroy(self_id));
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn wgpu_server_encode_compute_pass(
|
||||
global: &Global,
|
||||
self_id: id::CommandEncoderId,
|
||||
commands: *const core::command::ComputeCommand,
|
||||
command_length: usize,
|
||||
offsets: *const core::BufferAddress,
|
||||
offset_length: usize,
|
||||
) {
|
||||
let pass = core::command::StandaloneComputePass {
|
||||
commands: slice::from_raw_parts(commands, command_length),
|
||||
offsets: slice::from_raw_parts(offsets, offset_length),
|
||||
};
|
||||
gfx_select!(self_id => global.command_encoder_run_compute_pass(self_id, pass));
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn wgpu_server_queue_submit(
|
||||
global: &Global,
|
||||
self_id: id::QueueId,
|
||||
command_buffer_ids: *const id::CommandBufferId,
|
||||
command_buffer_id_length: usize,
|
||||
) {
|
||||
let command_buffers = slice::from_raw_parts(command_buffer_ids, command_buffer_id_length);
|
||||
gfx_select!(self_id => global.queue_submit(self_id, command_buffers));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user