Stand-alone compute passes

This commit is contained in:
Dzmitry Malyshau
2020-01-07 17:31:11 -05:00
parent c0fa61a064
commit 7808a4d4cd
5 changed files with 262 additions and 14 deletions

View File

@@ -6,24 +6,48 @@ use crate::{
command::{
bind::{Binder, LayoutChange},
CommandBuffer,
OffsetIndex,
},
device::{all_buffer_stages, BIND_BUFFER_ALIGNMENT},
hub::{GfxBackend, Global, IdentityFilter, Token},
id::{BindGroupId, BufferId, CommandBufferId, ComputePassId, ComputePipelineId},
id,
resource::BufferUsage,
track::TrackerSet,
BufferAddress,
Stored,
};
use hal::{self, command::CommandBuffer as _};
use hal::command::CommandBuffer as _;
use std::iter;
use std::{iter, ops::Range};
#[non_exhaustive]
#[derive(Debug)]
pub enum ComputeCommand {
SetBindGroup {
index: u32,
bind_group_id: id::BindGroupId,
offset_indices: Range<OffsetIndex>,
},
SetPipeline(id::ComputePipelineId),
Dispatch([u32; 3]),
DispatchIndirect {
buffer_id: id::BufferId,
offset: BufferAddress,
},
}
#[derive(Copy, Clone, Debug)]
pub struct StandaloneComputePass<'a> {
pub commands: &'a [ComputeCommand],
pub offsets: &'a [BufferAddress],
}
#[derive(Debug)]
pub struct ComputePass<B: hal::Backend> {
raw: B::CommandBuffer,
cmb_id: Stored<CommandBufferId>,
cmb_id: Stored<id::CommandBufferId>,
binder: Binder,
trackers: TrackerSet,
}
@@ -31,7 +55,7 @@ pub struct ComputePass<B: hal::Backend> {
impl<B: hal::Backend> ComputePass<B> {
pub(crate) fn new(
raw: B::CommandBuffer,
cmb_id: Stored<CommandBufferId>,
cmb_id: Stored<id::CommandBufferId>,
trackers: TrackerSet,
max_bind_groups: u32,
) -> Self {
@@ -46,8 +70,8 @@ impl<B: hal::Backend> ComputePass<B> {
// Common routines between render/compute
impl<F: IdentityFilter<ComputePassId>> Global<F> {
pub fn compute_pass_end_pass<B: GfxBackend>(&self, pass_id: ComputePassId) {
impl<F: IdentityFilter<id::ComputePassId>> Global<F> {
pub fn compute_pass_end_pass<B: GfxBackend>(&self, pass_id: id::ComputePassId) {
let mut token = Token::root();
let hub = B::hub(self);
let (mut cmb_guard, mut token) = hub.command_buffers.write(&mut token);
@@ -63,11 +87,154 @@ impl<F: IdentityFilter<ComputePassId>> Global<F> {
}
impl<F> Global<F> {
pub fn command_encoder_run_compute_pass<B: GfxBackend>(
&self,
encoder_id: id::CommandEncoderId,
pass: StandaloneComputePass,
) {
let hub = B::hub(self);
let mut token = Token::root();
let (mut cmb_guard, mut token) = hub.command_buffers.write(&mut token);
let cmb = &mut cmb_guard[encoder_id];
let raw = cmb.raw.last_mut().unwrap();
let mut binder = Binder::new(cmb.features.max_bind_groups);
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
let (buffer_guard, mut token) = hub.buffers.read(&mut token);
let (texture_guard, _) = hub.textures.read(&mut token);
for command in pass.commands {
match *command {
ComputeCommand::SetBindGroup { index, bind_group_id, ref offset_indices } => {
let offsets = &pass.offsets[offset_indices.start as usize .. offset_indices.end as usize];
if cfg!(debug_assertions) {
for off in offsets {
assert_eq!(
*off % BIND_BUFFER_ALIGNMENT,
0,
"Misaligned dynamic buffer offset: {} does not align with {}",
off,
BIND_BUFFER_ALIGNMENT
);
}
}
let bind_group = cmb
.trackers
.bind_groups
.use_extend(&*bind_group_guard, bind_group_id, (), ())
.unwrap();
assert_eq!(bind_group.dynamic_count, offsets.len());
log::trace!(
"Encoding barriers on binding of {:?} to {:?}",
bind_group_id,
encoder_id
);
CommandBuffer::insert_barriers(
raw,
&mut cmb.trackers,
&bind_group.used,
&*buffer_guard,
&*texture_guard,
);
if let Some((pipeline_layout_id, follow_ups)) = binder
.provide_entry(index as usize, bind_group_id, bind_group, offsets)
{
let bind_groups = iter::once(bind_group.raw.raw())
.chain(follow_ups.clone().map(|(bg_id, _)| bind_group_guard[bg_id].raw.raw()));
unsafe {
raw.bind_compute_descriptor_sets(
&pipeline_layout_guard[pipeline_layout_id].raw,
index as usize,
bind_groups,
offsets
.iter()
.chain(follow_ups.flat_map(|(_, offsets)| offsets))
.map(|&off| off as hal::command::DescriptorSetOffset),
);
}
}
}
ComputeCommand::SetPipeline(pipeline_id) => {
let pipeline = &pipeline_guard[pipeline_id];
unsafe {
raw.bind_compute_pipeline(&pipeline.raw);
}
// Rebind resources
if binder.pipeline_layout_id != Some(pipeline.layout_id) {
let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id];
binder.pipeline_layout_id = Some(pipeline.layout_id);
binder
.reset_expectations(pipeline_layout.bind_group_layout_ids.len());
let mut is_compatible = true;
for (index, (entry, &bgl_id)) in binder
.entries
.iter_mut()
.zip(&pipeline_layout.bind_group_layout_ids)
.enumerate()
{
match entry.expect_layout(bgl_id) {
LayoutChange::Match(bg_id, offsets) if is_compatible => {
let desc_set = bind_group_guard[bg_id].raw.raw();
unsafe {
raw.bind_compute_descriptor_sets(
&pipeline_layout.raw,
index,
iter::once(desc_set),
offsets.iter().map(|offset| *offset as u32),
);
}
}
LayoutChange::Match(..) | LayoutChange::Unchanged => {}
LayoutChange::Mismatch => {
is_compatible = false;
}
}
}
}
}
ComputeCommand::Dispatch(groups) => {
unsafe {
raw.dispatch(groups);
}
}
ComputeCommand::DispatchIndirect { buffer_id, offset } => {
let (src_buffer, src_pending) = cmb.trackers.buffers.use_replace(
&*buffer_guard,
buffer_id,
(),
BufferUsage::INDIRECT,
);
assert!(src_buffer.usage.contains(BufferUsage::INDIRECT));
let barriers = src_pending.map(|pending| pending.into_hal(src_buffer));
unsafe {
raw.pipeline_barrier(
all_buffer_stages() .. all_buffer_stages(),
hal::memory::Dependencies::empty(),
barriers,
);
raw.dispatch_indirect(&src_buffer.raw, offset);
}
}
}
}
}
pub fn compute_pass_set_bind_group<B: GfxBackend>(
&self,
pass_id: ComputePassId,
pass_id: id::ComputePassId,
index: u32,
bind_group_id: BindGroupId,
bind_group_id: id::BindGroupId,
offsets: &[BufferAddress],
) {
let hub = B::hub(self);
@@ -140,7 +307,7 @@ impl<F> Global<F> {
pub fn compute_pass_dispatch<B: GfxBackend>(
&self,
pass_id: ComputePassId,
pass_id: id::ComputePassId,
x: u32,
y: u32,
z: u32,
@@ -155,8 +322,8 @@ impl<F> Global<F> {
pub fn compute_pass_dispatch_indirect<B: GfxBackend>(
&self,
pass_id: ComputePassId,
indirect_buffer_id: BufferId,
pass_id: id::ComputePassId,
indirect_buffer_id: id::BufferId,
indirect_offset: BufferAddress,
) {
let hub = B::hub(self);
@@ -188,8 +355,8 @@ impl<F> Global<F> {
pub fn compute_pass_set_pipeline<B: GfxBackend>(
&self,
pass_id: ComputePassId,
pipeline_id: ComputePipelineId,
pass_id: id::ComputePassId,
pipeline_id: id::ComputePipelineId,
) {
let hub = B::hub(self);
let mut token = Token::root();

View File

@@ -55,6 +55,8 @@ use std::{
};
pub type OffsetIndex = u16;
pub struct RenderBundle<B: hal::Backend> {
_raw: B::CommandBuffer,
}

View File

@@ -170,6 +170,7 @@ impl<B: hal::Backend> Access<Device<B>> for Adapter<B> {}
impl<B: hal::Backend> Access<SwapChain<B>> for Device<B> {}
impl<B: hal::Backend> Access<PipelineLayout<B>> for Root {}
impl<B: hal::Backend> Access<PipelineLayout<B>> for Device<B> {}
impl<B: hal::Backend> Access<PipelineLayout<B>> for CommandBuffer<B> {}
impl<B: hal::Backend> Access<BindGroupLayout<B>> for Root {}
impl<B: hal::Backend> Access<BindGroupLayout<B>> for Device<B> {}
impl<B: hal::Backend> Access<BindGroup<B>> for Root {}
@@ -187,8 +188,10 @@ impl<B: hal::Backend> Access<RenderPass<B>> for Root {}
impl<B: hal::Backend> Access<RenderPass<B>> for BindGroup<B> {}
impl<B: hal::Backend> Access<RenderPass<B>> for CommandBuffer<B> {}
impl<B: hal::Backend> Access<ComputePipeline<B>> for Root {}
impl<B: hal::Backend> Access<ComputePipeline<B>> for BindGroup<B> {}
impl<B: hal::Backend> Access<ComputePipeline<B>> for ComputePass<B> {}
impl<B: hal::Backend> Access<RenderPipeline<B>> for Root {}
impl<B: hal::Backend> Access<RenderPipeline<B>> for BindGroup<B> {}
impl<B: hal::Backend> Access<RenderPipeline<B>> for RenderPass<B> {}
impl<B: hal::Backend> Access<ShaderModule<B>> for Root {}
impl<B: hal::Backend> Access<ShaderModule<B>> for PipelineLayout<B> {}

View File

@@ -20,6 +20,7 @@ struct IdentityHub {
adapters: IdentityManager,
devices: IdentityManager,
buffers: IdentityManager,
command_buffers: IdentityManager,
}
#[derive(Debug, Default)]
@@ -153,3 +154,30 @@ pub extern "C" fn wgpu_client_kill_buffer_id(client: &Client, id: id::BufferId)
.buffers
.free(id)
}
#[no_mangle]
pub extern "C" fn wgpu_client_make_encoder_id(
client: &Client,
device_id: id::DeviceId,
) -> id::CommandEncoderId {
let backend = device_id.backend();
client
.identities
.lock()
.select(backend)
.command_buffers
.alloc(backend)
}
#[no_mangle]
pub extern "C" fn wgpu_client_kill_encoder_id(
client: &Client,
id: id::CommandEncoderId,
) {
client
.identities
.lock()
.select(id.backend())
.command_buffers
.free(id)
}

View File

@@ -101,3 +101,51 @@ pub extern "C" fn wgpu_server_device_get_buffer_sub_data(
pub extern "C" fn wgpu_server_buffer_destroy(global: &Global, self_id: id::BufferId) {
gfx_select!(self_id => global.buffer_destroy(self_id));
}
#[no_mangle]
pub extern "C" fn wgpu_server_device_create_encoder(
global: &Global,
self_id: id::DeviceId,
encoder_id: id::CommandEncoderId,
) {
let desc = core::command::CommandEncoderDescriptor {
todo: 0,
};
gfx_select!(self_id => global.device_create_command_encoder(self_id, &desc, encoder_id));
}
#[no_mangle]
pub extern "C" fn wgpu_server_encoder_destroy(
_global: &Global,
_self_id: id::CommandEncoderId,
) {
//TODO
//gfx_select!(self_id => global.command_encoder_destroy(self_id));
}
#[no_mangle]
pub unsafe extern "C" fn wgpu_server_encode_compute_pass(
global: &Global,
self_id: id::CommandEncoderId,
commands: *const core::command::ComputeCommand,
command_length: usize,
offsets: *const core::BufferAddress,
offset_length: usize,
) {
let pass = core::command::StandaloneComputePass {
commands: slice::from_raw_parts(commands, command_length),
offsets: slice::from_raw_parts(offsets, offset_length),
};
gfx_select!(self_id => global.command_encoder_run_compute_pass(self_id, pass));
}
#[no_mangle]
pub unsafe extern "C" fn wgpu_server_queue_submit(
global: &Global,
self_id: id::QueueId,
command_buffer_ids: *const id::CommandBufferId,
command_buffer_id_length: usize,
) {
let command_buffers = slice::from_raw_parts(command_buffer_ids, command_buffer_id_length);
gfx_select!(self_id => global.queue_submit(self_id, command_buffers));
}