diff --git a/wgpu-core/src/device/life.rs b/wgpu-core/src/device/life.rs index 63270106af..017cb51544 100644 --- a/wgpu-core/src/device/life.rs +++ b/wgpu-core/src/device/life.rs @@ -2,7 +2,7 @@ use crate::device::trace; use crate::{ device::{ - queue::{EncoderInFlight, TempResource}, + queue::{EncoderInFlight, SubmittedWorkDoneClosure, TempResource}, DeviceError, }, hub::{GlobalIdentityHandlerFactory, HalApi, Hub, Token}, @@ -10,6 +10,7 @@ use crate::{ track::TrackerSet, RefCount, Stored, SubmissionIndex, }; +use smallvec::SmallVec; use copyless::VecHelper as _; use hal::Device as _; @@ -165,6 +166,7 @@ struct ActiveSubmission { last_resources: NonReferencedResources, mapped: Vec>, encoders: Vec>, + work_done_closures: SmallVec<[SubmittedWorkDoneClosure; 1]>, } #[derive(Clone, Debug, Error)] @@ -235,6 +237,7 @@ impl LifetimeTracker { last_resources, mapped: Vec::new(), encoders, + work_done_closures: SmallVec::new(), }); } @@ -256,11 +259,12 @@ impl LifetimeTracker { } /// Returns the last submission index that is done. 
+ #[must_use] pub fn triage_submissions( &mut self, last_done: SubmissionIndex, command_allocator: &Mutex>, - ) { + ) -> SmallVec<[SubmittedWorkDoneClosure; 1]> { profiling::scope!("triage_submissions"); //TODO: enable when `is_sorted_by_key` is stable @@ -271,6 +275,7 @@ impl LifetimeTracker { .position(|a| a.index > last_done) .unwrap_or_else(|| self.active.len()); + let mut work_done_closures = SmallVec::new(); for a in self.active.drain(..done_count) { log::trace!("Active submission {} is done", a.index); self.free_resources.extend(a.last_resources); @@ -279,7 +284,9 @@ impl LifetimeTracker { let raw = unsafe { encoder.land() }; command_allocator.lock().release_encoder(raw); } + work_done_closures.extend(a.work_done_closures); } + work_done_closures } pub fn cleanup(&mut self, device: &A::Device) { @@ -304,6 +311,18 @@ impl LifetimeTracker { TempResource::Texture(raw) => resources.textures.push(raw), } } + + pub fn add_work_done_closure(&mut self, closure: SubmittedWorkDoneClosure) -> bool { + match self.active.last_mut() { + Some(active) => { + active.work_done_closures.push(closure); + true + } + // Note: we can't immediately invoke the closure, since it assumes + // nothing is currently locked in the hubs. + None => false, + } + } } impl LifetimeTracker { @@ -621,18 +640,19 @@ impl LifetimeTracker { } } + #[must_use] pub(super) fn handle_mapping( &mut self, hub: &Hub, raw: &A::Device, trackers: &Mutex, token: &mut Token>, - ) -> Vec { + ) -> Vec { if self.ready_to_map.is_empty() { return Vec::new(); } let (mut buffer_guard, _) = hub.buffers.write(token); - let mut pending_callbacks: Vec = + let mut pending_callbacks: Vec = Vec::with_capacity(self.ready_to_map.len()); let mut trackers = trackers.lock(); for buffer_id in self.ready_to_map.drain(..) 
{ diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 9f8433db9b..67ddbcae39 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -15,6 +15,7 @@ use arrayvec::ArrayVec; use copyless::VecHelper as _; use hal::{CommandEncoder as _, Device as _}; use parking_lot::{Mutex, MutexGuard}; +use smallvec::SmallVec; use thiserror::Error; use wgt::{BufferAddress, TextureFormat, TextureViewDimension}; @@ -109,7 +110,31 @@ impl RenderPassContext { } } -type BufferMapPendingCallback = (resource::BufferMapOperation, resource::BufferMapAsyncStatus); +pub type BufferMapPendingClosure = (resource::BufferMapOperation, resource::BufferMapAsyncStatus); + +#[derive(Default)] +pub struct UserClosures { + pub mappings: Vec, + pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>, +} + +impl UserClosures { + fn extend(&mut self, other: UserClosures) { + self.mappings.extend(other.mappings); + self.submissions.extend(other.submissions); + } + + unsafe fn fire(self) { + //Note: this logic is specifically moved out of `handle_mapping()` in order to + // have nothing locked by the time we execute users callback code. + for (operation, status) in self.mappings { + (operation.callback)(status, operation.user_data); + } + for closure in self.submissions { + (closure.callback)(closure.user_data); + } + } +} fn map_buffer( raw: &A::Device, @@ -169,14 +194,6 @@ fn map_buffer( Ok(mapping.ptr) } -//Note: this logic is specifically moved out of `handle_mapping()` in order to -// have nothing locked by the time we execute users callback code. 
-fn fire_map_callbacks>(callbacks: I) { - for (operation, status) in callbacks { - unsafe { (operation.callback)(status, operation.user_data) } - } -} - struct CommandAllocator { free_encoders: Vec, } @@ -356,7 +373,7 @@ impl Device { hub: &Hub, force_wait: bool, token: &mut Token<'token, Self>, - ) -> Result, WaitIdleError> { + ) -> Result { profiling::scope!("maintain", "Device"); let mut life_tracker = self.lock_life(token); @@ -389,11 +406,15 @@ impl Device { } }; - life_tracker.triage_submissions(last_done_index, &self.command_allocator); - let callbacks = life_tracker.handle_mapping(hub, &self.raw, &self.trackers, token); + let submission_closures = + life_tracker.triage_submissions(last_done_index, &self.command_allocator); + let mapping_closures = life_tracker.handle_mapping(hub, &self.raw, &self.trackers, token); life_tracker.cleanup(&self.raw); - Ok(callbacks) + Ok(UserClosures { + mappings: mapping_closures, + submissions: submission_closures, + }) } fn untrack<'this, 'token: 'this, G: GlobalIdentityHandlerFactory>( @@ -2380,8 +2401,13 @@ impl Device { .wait(&self.fence, submission_index, !0) .map_err(DeviceError::from)? 
}; - self.lock_life(token) + let closures = self + .lock_life(token) .triage_submissions(submission_index, &self.command_allocator); + assert!( + closures.is_empty(), + "wait_for_submit is not expected to work with closures" + ); } Ok(()) } @@ -2462,7 +2488,7 @@ impl Device { if let Err(error) = unsafe { self.raw.wait(&self.fence, current_index, CLEANUP_WAIT_MS) } { log::error!("failed to wait for the device: {:?}", error); } - life_tracker.triage_submissions(current_index, &self.command_allocator); + let _ = life_tracker.triage_submissions(current_index, &self.command_allocator); life_tracker.cleanup(&self.raw); } @@ -4433,23 +4459,25 @@ impl Global { device_id: id::DeviceId, force_wait: bool, ) -> Result<(), WaitIdleError> { - let hub = A::hub(self); - let mut token = Token::root(); - let callbacks = { + let closures = { + let hub = A::hub(self); + let mut token = Token::root(); let (device_guard, mut token) = hub.devices.read(&mut token); device_guard .get(device_id) .map_err(|_| DeviceError::Invalid)? .maintain(hub, force_wait, &mut token)? 
}; - fire_map_callbacks(callbacks); + unsafe { + closures.fire(); + } Ok(()) } fn poll_devices( &self, force_wait: bool, - callbacks: &mut Vec, + closures: &mut UserClosures, ) -> Result<(), WaitIdleError> { profiling::scope!("poll_devices"); @@ -4458,32 +4486,34 @@ impl Global { let (device_guard, mut token) = hub.devices.read(&mut token); for (_, device) in device_guard.iter(A::VARIANT) { let cbs = device.maintain(hub, force_wait, &mut token)?; - callbacks.extend(cbs); + closures.extend(cbs); } Ok(()) } pub fn poll_all_devices(&self, force_wait: bool) -> Result<(), WaitIdleError> { - let mut callbacks = Vec::new(); + let mut closures = UserClosures::default(); #[cfg(vulkan)] { - self.poll_devices::(force_wait, &mut callbacks)?; + self.poll_devices::(force_wait, &mut closures)?; } #[cfg(metal)] { - self.poll_devices::(force_wait, &mut callbacks)?; + self.poll_devices::(force_wait, &mut closures)?; } #[cfg(dx12)] { - self.poll_devices::(force_wait, &mut callbacks)?; + self.poll_devices::(force_wait, &mut closures)?; } #[cfg(dx11)] { - self.poll_devices::(force_wait, &mut callbacks)?; + self.poll_devices::(force_wait, &mut closures)?; } - fire_map_callbacks(callbacks); + unsafe { + closures.fire(); + } Ok(()) } @@ -4659,7 +4689,7 @@ impl Global { fn buffer_unmap_inner( &self, buffer_id: id::BufferId, - ) -> Result, resource::BufferAccessError> { + ) -> Result, resource::BufferAccessError> { profiling::scope!("unmap", "Buffer"); let hub = A::hub(self); @@ -4773,8 +4803,13 @@ impl Global { &self, buffer_id: id::BufferId, ) -> Result<(), resource::BufferAccessError> { - self.buffer_unmap_inner::(buffer_id) - //Note: outside inner function so no locks are held when calling the callback - .map(|pending_callback| fire_map_callbacks(pending_callback.into_iter())) + //Note: outside inner function so no locks are held when calling the callback + let closure = self.buffer_unmap_inner::(buffer_id)?; + if let Some((operation, status)) = closure { + unsafe { + 
(operation.callback)(status, operation.user_data); + } + } + Ok(()) } } diff --git a/wgpu-core/src/device/queue.rs b/wgpu-core/src/device/queue.rs index 8c25456782..4e3bc5de63 100644 --- a/wgpu-core/src/device/queue.rs +++ b/wgpu-core/src/device/queue.rs @@ -26,6 +26,17 @@ use thiserror::Error; /// without a concrete moment of when it can be cleared. const WRITE_COMMAND_BUFFERS_PER_POOL: usize = 64; +pub type OnSubmittedWorkDoneCallback = unsafe extern "C" fn(user_data: *mut u8); +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct SubmittedWorkDoneClosure { + pub callback: OnSubmittedWorkDoneCallback, + pub user_data: *mut u8, +} + +unsafe impl Send for SubmittedWorkDoneClosure {} +unsafe impl Sync for SubmittedWorkDoneClosure {} + struct StagingData { buffer: A::Buffer, } @@ -506,10 +517,10 @@ impl Global { ) -> Result<(), QueueSubmitError> { profiling::scope!("submit", "Queue"); - let hub = A::hub(self); - let mut token = Token::root(); - let callbacks = { + let hub = A::hub(self); + let mut token = Token::root(); + let (mut device_guard, mut token) = hub.devices.write(&mut token); let device = device_guard .get_mut(queue_id) @@ -741,8 +752,8 @@ impl Global { // This will schedule destruction of all resources that are no longer needed // by the user but used in the command stream, among other things. - let callbacks = match device.maintain(hub, false, &mut token) { - Ok(callbacks) => callbacks, + let closures = match device.maintain(hub, false, &mut token) { + Ok(closures) => closures, Err(WaitIdleError::Device(err)) => return Err(QueueSubmitError::Queue(err)), Err(WaitIdleError::StuckGpu) => return Err(QueueSubmitError::StuckGpu), }; @@ -751,13 +762,13 @@ impl Global { device.temp_suspected.clear(); device.lock_life(&mut token).post_submit(); - callbacks + closures }; - // the map callbacks should execute with nothing locked! - drop(token); - super::fire_map_callbacks(callbacks); - + // the closures should execute with nothing locked! 
+ unsafe { + closures.fire(); + } Ok(()) } @@ -773,6 +784,29 @@ impl Global { Err(_) => Err(InvalidQueue), } } + + pub fn queue_on_submitted_work_done( + &self, + queue_id: id::QueueId, + closure: SubmittedWorkDoneClosure, + ) -> Result<(), InvalidQueue> { + //TODO: flush pending writes + let added = { + let hub = A::hub(self); + let mut token = Token::root(); + let (device_guard, mut token) = hub.devices.read(&mut token); + match device_guard.get(queue_id) { + Ok(device) => device.lock_life(&mut token).add_work_done_closure(closure), + Err(_) => return Err(InvalidQueue), + } + }; + if !added { + unsafe { + (closure.callback)(closure.user_data); + } + } + Ok(()) + } } fn get_lowest_common_denom(a: u32, b: u32) -> u32 { diff --git a/wgpu/src/backend/direct.rs b/wgpu/src/backend/direct.rs index b3e4530f39..b2fc8980b0 100644 --- a/wgpu/src/backend/direct.rs +++ b/wgpu/src/backend/direct.rs @@ -732,6 +732,7 @@ impl crate::Context for Context { type RequestDeviceFuture = Ready>; type MapAsyncFuture = native_gpu_future::GpuFuture>; + type OnSubmittedWorkDoneFuture = native_gpu_future::GpuFuture<()>; fn init(backends: wgt::Backends) -> Self { Self(wgc::hub::Global::new( @@ -2055,6 +2056,31 @@ impl crate::Context for Context { } } + fn queue_on_submitted_work_done( + &self, + queue: &Self::QueueId, + ) -> Self::OnSubmittedWorkDoneFuture { + let (future, completion) = native_gpu_future::new_gpu_future(); + + extern "C" fn submitted_work_done_future_wrapper(user_data: *mut u8) { + let completion = + unsafe { native_gpu_future::GpuFutureCompletion::from_raw(user_data as _) }; + completion.complete(()) + } + + let closure = wgc::device::queue::SubmittedWorkDoneClosure { + callback: submitted_work_done_future_wrapper, + user_data: completion.into_raw() as _, + }; + + let global = &self.0; + let res = wgc::gfx_select!(queue => global.queue_on_submitted_work_done(*queue, closure)); + if let Err(cause) = res { + self.handle_error_fatal(cause, "Queue::on_submitted_work_done"); + }
+ future + } + fn device_start_capture(&self, device: &Self::DeviceId) { let global = &self.0; wgc::gfx_select!(device.id => global.device_start_capture(device.id)); diff --git a/wgpu/src/backend/web.rs b/wgpu/src/backend/web.rs index 62e829d317..7b678f3eb6 100644 --- a/wgpu/src/backend/web.rs +++ b/wgpu/src/backend/web.rs @@ -925,6 +925,8 @@ impl crate::Context for Context { wasm_bindgen_futures::JsFuture, fn(JsFutureResult) -> Result<(), crate::BufferAsyncError>, >; + type OnSubmittedWorkDoneFuture = + MakeSendFuture ()>; fn init(_backends: wgt::Backends) -> Self { Context(web_sys::window().unwrap().navigator().gpu()) @@ -2035,6 +2037,13 @@ impl crate::Context for Context { 1.0 //TODO } + fn queue_on_submitted_work_done( + &self, + _queue: &Self::QueueId, + ) -> Self::OnSubmittedWorkDoneFuture { + unimplemented!() + } + fn device_start_capture(&self, _device: &Self::DeviceId) {} fn device_stop_capture(&self, _device: &Self::DeviceId) {} } diff --git a/wgpu/src/lib.rs b/wgpu/src/lib.rs index 5da8ddef2a..9fb52697b7 100644 --- a/wgpu/src/lib.rs +++ b/wgpu/src/lib.rs @@ -182,6 +182,7 @@ trait Context: Debug + Send + Sized + Sync { type RequestDeviceFuture: Future> + Send; type MapAsyncFuture: Future> + Send; + type OnSubmittedWorkDoneFuture: Future + Send; fn init(backends: Backends) -> Self; fn instance_create_surface( @@ -466,6 +467,10 @@ trait Context: Debug + Send + Sized + Sync { command_buffers: I, ); fn queue_get_timestamp_period(&self, queue: &Self::QueueId) -> f32; + fn queue_on_submitted_work_done( + &self, + queue: &Self::QueueId, + ) -> Self::OnSubmittedWorkDoneFuture; fn device_start_capture(&self, device: &Self::DeviceId); fn device_stop_capture(&self, device: &Self::DeviceId); @@ -3071,6 +3076,12 @@ impl Queue { pub fn get_timestamp_period(&self) -> f32 { Context::queue_get_timestamp_period(&*self.context, &self.id) } + + /// Returns a future that resolves once all the work submitted by this point + /// is done processing on GPU. 
+ pub fn on_submitted_work_done(&self) -> impl Future<Output = ()> + Send { + Context::queue_on_submitted_work_done(&*self.context, &self.id) + } } impl Drop for SwapChainTexture {