diff --git a/wgpu-core/src/command/bundle.rs b/wgpu-core/src/command/bundle.rs index 82648e0e1c..653ed91303 100644 --- a/wgpu-core/src/command/bundle.rs +++ b/wgpu-core/src/command/bundle.rs @@ -785,6 +785,8 @@ impl RenderBundle { } } + let snatch_guard = self.device.snatchable_lock.read(); + for command in self.base.commands.iter() { match *command { RenderCommand::SetBindGroup { @@ -818,7 +820,7 @@ impl RenderBundle { size, } => { let buffers = trackers.buffers.read(); - let buffer = buffers.get(buffer_id).unwrap().raw(); + let buffer = buffers.get(buffer_id).unwrap().raw(&snatch_guard); let bb = hal::BufferBinding { buffer, offset, @@ -833,7 +835,7 @@ impl RenderBundle { size, } => { let buffers = trackers.buffers.read(); - let buffer = buffers.get(buffer_id).unwrap().raw(); + let buffer = buffers.get(buffer_id).unwrap().raw(&snatch_guard); let bb = hal::BufferBinding { buffer, offset, @@ -912,7 +914,7 @@ impl RenderBundle { indexed: false, } => { let buffers = trackers.buffers.read(); - let buffer = buffers.get(buffer_id).unwrap().raw(); + let buffer = buffers.get(buffer_id).unwrap().raw(&snatch_guard); unsafe { raw.draw_indirect(buffer, offset, 1) }; } RenderCommand::MultiDrawIndirect { @@ -922,7 +924,7 @@ impl RenderBundle { indexed: true, } => { let buffers = trackers.buffers.read(); - let buffer = buffers.get(buffer_id).unwrap().raw(); + let buffer = buffers.get(buffer_id).unwrap().raw(&snatch_guard); unsafe { raw.draw_indexed_indirect(buffer, offset, 1) }; } RenderCommand::MultiDrawIndirect { .. } diff --git a/wgpu-core/src/command/clear.rs b/wgpu-core/src/command/clear.rs index d2dd1bbe1a..b58738999e 100644 --- a/wgpu-core/src/command/clear.rs +++ b/wgpu-core/src/command/clear.rs @@ -102,9 +102,10 @@ impl Global { .set_single(dst_buffer, hal::BufferUses::COPY_DST) .ok_or(ClearError::InvalidBuffer(dst))? 
}; + let snatch_guard = dst_buffer.device.snatchable_lock.read(); let dst_raw = dst_buffer .raw - .as_ref() + .get(&snatch_guard) .ok_or(ClearError::InvalidBuffer(dst))?; if !dst_buffer.usage.contains(BufferUsages::COPY_DST) { return Err(ClearError::MissingCopyDstUsageFlag(Some(dst), None)); @@ -145,8 +146,9 @@ impl Global { MemoryInitKind::ImplicitlyInitialized, ), ); + // actual hal barrier & operation - let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer)); + let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard)); let cmd_buf_raw = cmd_buf_data.encoder.open(); unsafe { cmd_buf_raw.transition_buffers(dst_barrier.into_iter()); diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index c457a33186..8bc8f43cd4 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -1,4 +1,5 @@ use crate::resource::Resource; +use crate::snatch::SnatchGuard; use crate::{ binding_model::{ BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError, @@ -310,6 +311,7 @@ impl State { base_trackers: &mut Tracker, bind_group_guard: &Storage, id::BindGroupId>, indirect_buffer: Option, + snatch_guard: &SnatchGuard, ) -> Result<(), UsageConflict> { for id in self.binder.list_active() { unsafe { self.scope.merge_bind_group(&bind_group_guard[id].used)? 
}; @@ -335,7 +337,7 @@ impl State { log::trace!("Encoding dispatch barriers"); - CommandBuffer::drain_barriers(raw_encoder, base_trackers); + CommandBuffer::drain_barriers(raw_encoder, base_trackers, snatch_guard); Ok(()) } } @@ -453,6 +455,8 @@ impl Global { None }; + let snatch_guard = device.snatchable_lock.read(); + tracker.set_size( Some(&*buffer_guard), Some(&*texture_guard), @@ -673,7 +677,13 @@ impl Global { state.is_ready().map_pass_err(scope)?; state - .flush_states(raw, &mut intermediate_trackers, &*bind_group_guard, None) + .flush_states( + raw, + &mut intermediate_trackers, + &*bind_group_guard, + None, + &snatch_guard, + ) .map_pass_err(scope)?; let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension; @@ -727,7 +737,7 @@ impl Global { let buf_raw = indirect_buffer .raw - .as_ref() + .get(&snatch_guard) .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id)) .map_pass_err(scope)?; @@ -747,6 +757,7 @@ impl Global { &mut intermediate_trackers, &*bind_group_guard, Some(buffer_id), + &snatch_guard, ) .map_pass_err(scope)?; unsafe { @@ -860,7 +871,12 @@ impl Global { &mut tracker.textures, device, ); - CommandBuffer::insert_barriers_from_tracker(transit, tracker, &intermediate_trackers); + CommandBuffer::insert_barriers_from_tracker( + transit, + tracker, + &intermediate_trackers, + &snatch_guard, + ); // Close the command buffer, and swap it with the previous. 
encoder.close_and_swap(); diff --git a/wgpu-core/src/command/memory_init.rs b/wgpu-core/src/command/memory_init.rs index f10c85e2be..81456ece0a 100644 --- a/wgpu-core/src/command/memory_init.rs +++ b/wgpu-core/src/command/memory_init.rs @@ -223,12 +223,16 @@ impl BakedCommands { .unwrap() .1; - let raw_buf = buffer.raw.as_ref().ok_or(DestroyedBufferError(buffer_id))?; + let snatch_guard = buffer.device.snatchable_lock.read(); + let raw_buf = buffer + .raw + .get(&snatch_guard) + .ok_or(DestroyedBufferError(buffer_id))?; unsafe { self.encoder.transition_buffers( transition - .map(|pending| pending.into_hal(&buffer)) + .map(|pending| pending.into_hal(&buffer, &snatch_guard)) .into_iter(), ); } diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index fce9d2c6a1..64f39030b6 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -22,6 +22,7 @@ use crate::device::Device; use crate::error::{ErrorFormatter, PrettyError}; use crate::hub::Hub; use crate::id::CommandBufferId; +use crate::snatch::SnatchGuard; use crate::init_tracker::BufferInitTrackerAction; use crate::resource::{Resource, ResourceInfo, ResourceType}; @@ -193,32 +194,38 @@ impl CommandBuffer { raw: &mut A::CommandEncoder, base: &mut Tracker, head: &Tracker, + snatch_guard: &SnatchGuard, ) { profiling::scope!("insert_barriers"); base.buffers.set_from_tracker(&head.buffers); base.textures.set_from_tracker(&head.textures); - Self::drain_barriers(raw, base); + Self::drain_barriers(raw, base, snatch_guard); } pub(crate) fn insert_barriers_from_scope( raw: &mut A::CommandEncoder, base: &mut Tracker, head: &UsageScope, + snatch_guard: &SnatchGuard, ) { profiling::scope!("insert_barriers"); base.buffers.set_from_usage_scope(&head.buffers); base.textures.set_from_usage_scope(&head.textures); - Self::drain_barriers(raw, base); + Self::drain_barriers(raw, base, snatch_guard); } - pub(crate) fn drain_barriers(raw: &mut A::CommandEncoder, base: &mut Tracker) { + pub(crate) fn 
drain_barriers( + raw: &mut A::CommandEncoder, + base: &mut Tracker, + snatch_guard: &SnatchGuard, + ) { profiling::scope!("drain_barriers"); - let buffer_barriers = base.buffers.drain_transitions(); + let buffer_barriers = base.buffers.drain_transitions(snatch_guard); let (transitions, textures) = base.textures.drain_transitions(); let texture_barriers = transitions .into_iter() diff --git a/wgpu-core/src/command/query.rs b/wgpu-core/src/command/query.rs index 3d4404b41e..bb6136b544 100644 --- a/wgpu-core/src/command/query.rs +++ b/wgpu-core/src/command/query.rs @@ -430,7 +430,10 @@ impl Global { .set_single(dst_buffer, hal::BufferUses::COPY_DST) .ok_or(QueryError::InvalidBuffer(destination))? }; - let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer)); + + let snatch_guard = dst_buffer.device.snatchable_lock.read(); + + let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard)); if !dst_buffer.usage.contains(wgt::BufferUsages::QUERY_RESOLVE) { return Err(ResolveError::MissingBufferUsage.into()); @@ -481,7 +484,7 @@ impl Global { raw_encoder.copy_query_results( query_set.raw(), start_query..end_query, - dst_buffer.raw(), + dst_buffer.raw(&snatch_guard), destination_offset, wgt::BufferSize::new_unchecked(stride as u64), ); diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index a2bb291a00..0ed8f46aa9 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -1308,8 +1308,11 @@ impl Global { let hub = A::hub(self); + let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(init_scope)?; + let device = &cmd_buf.device; + let snatch_guard = device.snatchable_lock.read(); + let (scope, pending_discard_init_fixups) = { - let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(init_scope)?; let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); @@ -1324,7 +1327,6 @@ impl Global { }); } - let 
device = &cmd_buf.device; if !device.is_valid() { return Err(DeviceError::Lost).map_pass_err(init_scope); } @@ -1639,7 +1641,7 @@ impl Global { .map_pass_err(scope)?; let buf_raw = buffer .raw - .as_ref() + .get(&snatch_guard) .ok_or(RenderCommandError::DestroyedBuffer(buffer_id)) .map_pass_err(scope)?; @@ -1701,7 +1703,7 @@ impl Global { .map_pass_err(scope)?; let buf_raw = buffer .raw - .as_ref() + .get(&snatch_guard) .ok_or(RenderCommandError::DestroyedBuffer(buffer_id)) .map_pass_err(scope)?; @@ -1991,7 +1993,7 @@ impl Global { .map_pass_err(scope)?; let indirect_raw = indirect_buffer .raw - .as_ref() + .get(&snatch_guard) .ok_or(RenderCommandError::DestroyedBuffer(buffer_id)) .map_pass_err(scope)?; @@ -2063,7 +2065,7 @@ impl Global { .map_pass_err(scope)?; let indirect_raw = indirect_buffer .raw - .as_ref() + .get(&snatch_guard) .ok_or(RenderCommandError::DestroyedBuffer(buffer_id)) .map_pass_err(scope)?; @@ -2080,7 +2082,7 @@ impl Global { .map_pass_err(scope)?; let count_raw = count_buffer .raw - .as_ref() + .get(&snatch_guard) .ok_or(RenderCommandError::DestroyedBuffer(count_buffer_id)) .map_pass_err(scope)?; @@ -2385,7 +2387,12 @@ impl Global { .map_err(RenderCommandError::InvalidQuerySet) .map_pass_err(PassErrorScope::QueryReset)?; - super::CommandBuffer::insert_barriers_from_scope(transit, tracker, &scope); + super::CommandBuffer::insert_barriers_from_scope( + transit, + tracker, + &scope, + &snatch_guard, + ); } *status = CommandEncoderStatus::Recording; diff --git a/wgpu-core/src/command/transfer.rs b/wgpu-core/src/command/transfer.rs index 7ec1bcaebd..359bfbfae3 100644 --- a/wgpu-core/src/command/transfer.rs +++ b/wgpu-core/src/command/transfer.rs @@ -597,6 +597,8 @@ impl Global { }); } + let snatch_guard = device.snatchable_lock.read(); + let (src_buffer, src_pending) = { let buffer_guard = hub.buffers.read(); let src_buffer = buffer_guard @@ -610,13 +612,13 @@ impl Global { }; let src_raw = src_buffer .raw - .as_ref() + .get(&snatch_guard) 
.ok_or(TransferError::InvalidBuffer(source))?; if !src_buffer.usage.contains(BufferUsages::COPY_SRC) { return Err(TransferError::MissingCopySrcUsageFlag.into()); } // expecting only a single barrier - let src_barrier = src_pending.map(|pending| pending.into_hal(&src_buffer)); + let src_barrier = src_pending.map(|pending| pending.into_hal(&src_buffer, &snatch_guard)); let (dst_buffer, dst_pending) = { let buffer_guard = hub.buffers.read(); @@ -631,12 +633,12 @@ impl Global { }; let dst_raw = dst_buffer .raw - .as_ref() + .get(&snatch_guard) .ok_or(TransferError::InvalidBuffer(destination))?; if !dst_buffer.usage.contains(BufferUsages::COPY_DST) { return Err(TransferError::MissingCopyDstUsageFlag(Some(destination), None).into()); } - let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer)); + let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard)); if size % wgt::COPY_BUFFER_ALIGNMENT != 0 { return Err(TransferError::UnalignedCopySize(size).into()); @@ -795,6 +797,8 @@ impl Global { &texture_guard, )?; + let snatch_guard = device.snatchable_lock.read(); + let (src_buffer, src_pending) = { let buffer_guard = hub.buffers.read(); let src_buffer = buffer_guard @@ -807,12 +811,12 @@ impl Global { }; let src_raw = src_buffer .raw - .as_ref() + .get(&snatch_guard) .ok_or(TransferError::InvalidBuffer(source.buffer))?; if !src_buffer.usage.contains(BufferUsages::COPY_SRC) { return Err(TransferError::MissingCopySrcUsageFlag.into()); } - let src_barrier = src_pending.map(|pending| pending.into_hal(&src_buffer)); + let src_barrier = src_pending.map(|pending| pending.into_hal(&src_buffer, &snatch_guard)); let dst_pending = tracker .textures @@ -953,6 +957,8 @@ impl Global { &texture_guard, )?; + let snatch_guard = device.snatchable_lock.read(); + let src_pending = tracker .textures .set_single(src_texture, src_range, hal::TextureUses::COPY_SRC) @@ -993,14 +999,14 @@ impl Global { }; let dst_raw = dst_buffer .raw - .as_ref() + 
.get(&snatch_guard) .ok_or(TransferError::InvalidBuffer(destination.buffer))?; if !dst_buffer.usage.contains(BufferUsages::COPY_DST) { return Err( TransferError::MissingCopyDstUsageFlag(Some(destination.buffer), None).into(), ); } - let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer)); + let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard)); if !src_base.aspect.is_one() { return Err(TransferError::CopyAspectNotOne.into()); diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 7a6caec6c9..a55d1563ef 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -239,7 +239,12 @@ impl Global { let stage_fid = hub.buffers.request(); let stage = stage_fid.init(stage); - let mapping = match unsafe { device.raw().map_buffer(stage.raw(), 0..stage.size) } { + let snatch_guard = device.snatchable_lock.read(); + let mapping = match unsafe { + device + .raw() + .map_buffer(stage.raw(&snatch_guard), 0..stage.size) + } { Ok(mapping) => mapping, Err(e) => { let mut life_lock = device.lock_life(); @@ -380,6 +385,7 @@ impl Global { .devices .get(device_id) .map_err(|_| DeviceError::Invalid)?; + let snatch_guard = device.snatchable_lock.read(); if !device.is_valid() { return Err(DeviceError::Lost.into()); } @@ -402,7 +408,7 @@ impl Global { }); } - let raw_buf = buffer.raw(); + let raw_buf = buffer.raw(&snatch_guard); unsafe { let mapping = device .raw() @@ -443,6 +449,8 @@ impl Global { return Err(DeviceError::Lost.into()); } + let snatch_guard = device.snatchable_lock.read(); + let buffer = hub .buffers .get(buffer_id) @@ -450,7 +458,7 @@ impl Global { check_buffer_usage(buffer.usage, wgt::BufferUsages::MAP_READ)?; //assert!(buffer isn't used by the GPU); - let raw_buf = buffer.raw(); + let raw_buf = buffer.raw(&snatch_guard); unsafe { let mapping = device .raw() @@ -2406,7 +2414,8 @@ impl Global { let mut trackers = buffer.device.as_ref().trackers.lock(); 
trackers.buffers.set_single(&buffer, internal_use); //TODO: Check if draining ALL buffers is correct! - let _ = trackers.buffers.drain_transitions(); + let snatch_guard = device.snatchable_lock.read(); + let _ = trackers.buffers.drain_transitions(&snatch_guard); } buffer diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index b8ebaf46c8..6cc8f5acef 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -294,14 +294,18 @@ fn map_buffer( size: BufferAddress, kind: HostMap, ) -> Result, BufferAccessError> { + let snatch_guard = buffer.device.snatchable_lock.read(); let mapping = unsafe { - raw.map_buffer(buffer.raw(), offset..offset + size) + raw.map_buffer(buffer.raw(&snatch_guard), offset..offset + size) .map_err(DeviceError::from)? }; *buffer.sync_mapped_writes.lock() = match kind { HostMap::Read if !mapping.is_coherent => unsafe { - raw.invalidate_mapped_ranges(buffer.raw(), iter::once(offset..offset + size)); + raw.invalidate_mapped_ranges( + buffer.raw(&snatch_guard), + iter::once(offset..offset + size), + ); None }, HostMap::Write if !mapping.is_coherent => Some(offset..offset + size), @@ -341,7 +345,9 @@ fn map_buffer( mapped[fill_range].fill(0); if zero_init_needs_flush_now { - unsafe { raw.flush_mapped_ranges(buffer.raw(), iter::once(uninitialized)) }; + unsafe { + raw.flush_mapped_ranges(buffer.raw(&snatch_guard), iter::once(uninitialized)) + }; } } diff --git a/wgpu-core/src/device/queue.rs b/wgpu-core/src/device/queue.rs index 9a1918c0a1..2552369649 100644 --- a/wgpu-core/src/device/queue.rs +++ b/wgpu-core/src/device/queue.rs @@ -593,9 +593,10 @@ impl Global { .set_single(dst, hal::BufferUses::COPY_DST) .ok_or(TransferError::InvalidBuffer(buffer_id))? 
}; + let snatch_guard = device.snatchable_lock.read(); let dst_raw = dst .raw - .as_ref() + .get(&snatch_guard) .ok_or(TransferError::InvalidBuffer(buffer_id))?; if dst.device.as_info().id() != device.as_info().id() { @@ -618,7 +619,7 @@ impl Global { buffer: inner_buffer.as_ref().unwrap(), usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC, }) - .chain(transition.map(|pending| pending.into_hal(&dst))); + .chain(transition.map(|pending| pending.into_hal(&dst, &snatch_guard))); let encoder = pending_writes.activate(); unsafe { encoder.transition_buffers(barriers); @@ -1139,6 +1140,8 @@ impl Global { let mut active_executions = Vec::new(); let mut used_surface_textures = track::TextureUsageScope::new(); + let snatch_guard = device.snatchable_lock.read(); + { let mut command_buffer_guard = hub.command_buffers.write(); @@ -1207,8 +1210,8 @@ impl Global { // update submission IDs for buffer in cmd_buf_trackers.buffers.used_resources() { let id = buffer.info.id(); - let raw_buf = match buffer.raw { - Some(ref raw) => raw, + let raw_buf = match buffer.raw.get(&snatch_guard) { + Some(raw) => raw, None => { return Err(QueueSubmitError::DestroyedBuffer(id)); } @@ -1362,6 +1365,7 @@ impl Global { &mut baked.encoder, &mut *trackers, &baked.trackers, + &snatch_guard, ); let transit = unsafe { baked.encoder.end_encoding().unwrap() }; diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index a3dd162bfb..9e27888f39 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -26,6 +26,7 @@ use crate::{ TextureViewNotRenderableReason, }, resource_log, + snatch::{SnatchGuard, SnatchLock, Snatchable}, storage::Storage, track::{BindGroupStates, TextureSelector, Tracker}, validation::{self, check_buffer_usage, check_texture_usage}, @@ -94,6 +95,7 @@ pub struct Device { //Note: The submission index here corresponds to the last submission that is done. 
pub(crate) active_submission_index: AtomicU64, //SubmissionIndex, pub(crate) fence: RwLock>, + pub(crate) snatchable_lock: SnatchLock, /// Is this device valid? Valid is closely associated with "lose the device", /// which can be triggered by various methods, including at the end of device @@ -254,6 +256,7 @@ impl Device { command_allocator: Mutex::new(Some(com_alloc)), active_submission_index: AtomicU64::new(0), fence: RwLock::new(Some(fence)), + snatchable_lock: unsafe { SnatchLock::new() }, valid: AtomicBool::new(true), trackers: Mutex::new(Tracker::new()), life_tracker: Mutex::new(life::LifetimeTracker::new()), @@ -538,7 +541,7 @@ impl Device { let buffer = unsafe { self.raw().create_buffer(&hal_desc) }.map_err(DeviceError::from)?; Ok(Buffer { - raw: Some(buffer), + raw: Snatchable::new(buffer), device: self.clone(), usage: desc.usage, size: desc.size, @@ -588,7 +591,7 @@ impl Device { debug_assert_eq!(self.as_info().id().backend(), A::VARIANT); Buffer { - raw: Some(hal_buffer), + raw: Snatchable::new(hal_buffer), device: self.clone(), usage: desc.usage, size: desc.size, @@ -1766,6 +1769,7 @@ impl Device { used: &mut BindGroupStates, storage: &'a Storage, id::BufferId>, limits: &wgt::Limits, + snatch_guard: &'a SnatchGuard<'a>, ) -> Result, binding_model::CreateBindGroupError> { use crate::binding_model::CreateBindGroupError as Error; @@ -1818,7 +1822,7 @@ impl Device { check_buffer_usage(buffer.usage, pub_usage)?; let raw_buffer = buffer .raw - .as_ref() + .get(snatch_guard) .ok_or(Error::InvalidBuffer(bb.buffer_id))?; let (bind_size, bind_end) = match bb.size { @@ -1966,6 +1970,7 @@ impl Device { let mut hal_buffers = Vec::new(); let mut hal_samplers = Vec::new(); let mut hal_textures = Vec::new(); + let snatch_guard = self.snatchable_lock.read(); for entry in desc.entries.iter() { let binding = entry.binding; // Find the corresponding declaration in the layout @@ -1985,6 +1990,7 @@ impl Device { &mut used, &*buffer_guard, &self.limits, + &snatch_guard, )?; 
let res_index = hal_buffers.len(); @@ -2007,6 +2013,7 @@ impl Device { &mut used, &*buffer_guard, &self.limits, + &snatch_guard, )?; hal_buffers.push(bb); } diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index e68debdcb8..a4bb0e71b5 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -64,6 +64,7 @@ pub mod pipeline; pub mod present; pub mod registry; pub mod resource; +mod snatch; pub mod storage; mod track; // This is public for users who pre-compile shaders while still wanting to diff --git a/wgpu-core/src/resource.rs b/wgpu-core/src/resource.rs index 587b479206..3e0f790d1e 100644 --- a/wgpu-core/src/resource.rs +++ b/wgpu-core/src/resource.rs @@ -14,6 +14,7 @@ use crate::{ identity::{GlobalIdentityHandlerFactory, IdentityManager}, init_tracker::{BufferInitTracker, TextureInitTracker}, resource, resource_log, + snatch::{SnatchGuard, Snatchable}, track::TextureSelector, validation::MissingBufferUsageError, Label, SubmissionIndex, @@ -395,7 +396,7 @@ pub type BufferDescriptor<'a> = wgt::BufferDescriptor>; #[derive(Debug)] pub struct Buffer { - pub(crate) raw: Option, + pub(crate) raw: Snatchable, pub(crate) device: Arc>, pub(crate) usage: wgt::BufferUsages, pub(crate) size: wgt::BufferAddress, @@ -418,8 +419,8 @@ impl Drop for Buffer { } impl Buffer { - pub(crate) fn raw(&self) -> &A::Buffer { - self.raw.as_ref().unwrap() + pub(crate) fn raw(&self, guard: &SnatchGuard) -> &A::Buffer { + self.raw.get(guard).unwrap() } pub(crate) fn buffer_unmap_inner( @@ -428,6 +429,7 @@ impl Buffer { use hal::Device; let device = &self.device; + let snatch_guard = device.snatchable_lock.read(); let buffer_id = self.info.id(); log::debug!("Buffer {:?} map state -> Idle", buffer_id); match mem::replace(&mut *self.map_state.lock(), resource::BufferMapState::Idle) { @@ -451,13 +453,17 @@ impl Buffer { let _ = ptr; if needs_flush { unsafe { - device - .raw() - .flush_mapped_ranges(stage_buffer.raw(), iter::once(0..self.size)); + device.raw().flush_mapped_ranges( + 
stage_buffer.raw(&snatch_guard), + iter::once(0..self.size), + ); } } - let raw_buf = self.raw.as_ref().ok_or(BufferAccessError::Destroyed)?; + let raw_buf = self + .raw + .get(&snatch_guard) + .ok_or(BufferAccessError::Destroyed)?; self.info .use_at(device.active_submission_index.load(Ordering::Relaxed) + 1); @@ -467,7 +473,7 @@ impl Buffer { size, }); let transition_src = hal::BufferBarrier { - buffer: stage_buffer.raw(), + buffer: stage_buffer.raw(&snatch_guard), usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC, }; let transition_dst = hal::BufferBarrier { @@ -483,7 +489,7 @@ impl Buffer { ); if self.size > 0 { encoder.copy_buffer_to_buffer( - stage_buffer.raw(), + stage_buffer.raw(&snatch_guard), raw_buf, region.into_iter(), ); @@ -518,7 +524,7 @@ impl Buffer { unsafe { device .raw() - .unmap_buffer(self.raw()) + .unmap_buffer(self.raw(&snatch_guard)) .map_err(DeviceError::from)? }; } @@ -539,7 +545,10 @@ impl Buffer { if let Some(ref mut trace) = *device.trace.lock() { trace.add(trace::Action::FreeBuffer(buffer_id)); } - if self.raw.is_none() { + // Note: a future commit will replace this with a read guard + // and actually snatch the buffer. + let snatch_guard = device.snatchable_lock.read(); + if self.raw.get(&snatch_guard).is_none() { return Err(resource::DestroyError::AlreadyDestroyed); } diff --git a/wgpu-core/src/snatch.rs b/wgpu-core/src/snatch.rs new file mode 100644 index 0000000000..db3114076c --- /dev/null +++ b/wgpu-core/src/snatch.rs @@ -0,0 +1,86 @@ +#![allow(unused)] + +use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::cell::UnsafeCell; + +/// A guard that provides read access to snatchable data. +pub struct SnatchGuard<'a>(RwLockReadGuard<'a, ()>); +/// A guard that allows snatching the snatchable data. +pub struct ExclusiveSnatchGuard<'a>(RwLockWriteGuard<'a, ()>); + +/// A value that is mostly immutable but can be "snatched" if we need to destroy +/// it early. 
+/// it early.
+///
+/// In order to safely access the underlying data, the device's global snatchable
+/// lock must be taken. To guarantee it, methods take a read or write guard of that
+/// special lock.
+pub struct Snatchable<T> {
+    value: UnsafeCell<Option<T>>,
+}
+
+impl<T> Snatchable<T> {
+    pub fn new(val: T) -> Self {
+        Snatchable {
+            value: UnsafeCell::new(Some(val)),
+        }
+    }
+
+    /// Get read access to the value. Requires the snatchable lock's read guard.
+    pub fn get(&self, _guard: &SnatchGuard) -> Option<&T> {
+        unsafe { (*self.value.get()).as_ref() }
+    }
+
+    /// Take the value. Requires the snatchable lock's write guard.
+    pub fn snatch(&self, _guard: ExclusiveSnatchGuard) -> Option<T> {
+        unsafe { (*self.value.get()).take() }
+    }
+
+    /// Take the value without a guard. This can only be used with exclusive access
+    /// to self, so it does not require locking.
+    ///
+    /// Typically useful in a drop implementation.
+    pub fn take(&mut self) -> Option<T> {
+        self.value.get_mut().take()
+    }
+}
+
+// Can't safely print the contents of a snatchable object without holding
+// the lock.
+impl<T> std::fmt::Debug for Snatchable<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "<snatchable>")
+    }
+}
+
+unsafe impl<T> Sync for Snatchable<T> {}
+
+/// A Device-global lock for all snatchable data.
+pub struct SnatchLock {
+    lock: RwLock<()>,
+}
+
+impl SnatchLock {
+    /// The safety of `Snatchable::get` and `Snatchable::snatch` relies on their use of the
+    /// right SnatchLock (the one associated with the same device). This method is unsafe
+    /// to force users to think twice about creating a SnatchLock. The only place this
+    /// method should be called is when creating the device.
+    pub unsafe fn new() -> Self {
+        SnatchLock {
+            lock: RwLock::new(()),
+        }
+    }
+
+    /// Request read access to snatchable resources.
+    pub fn read(&self) -> SnatchGuard {
+        SnatchGuard(self.lock.read())
+    }
+
+    /// Request write access to snatchable resources.
+ /// + /// This should only be called when a resource needs to be snatched. This has + /// a high risk of causing lock contention if called concurrently with other + /// wgpu work. + pub fn write(&self) -> ExclusiveSnatchGuard { + ExclusiveSnatchGuard(self.lock.write()) + } +} diff --git a/wgpu-core/src/track/buffer.rs b/wgpu-core/src/track/buffer.rs index 4b2a9a8abe..2c2a6937f9 100644 --- a/wgpu-core/src/track/buffer.rs +++ b/wgpu-core/src/track/buffer.rs @@ -12,6 +12,7 @@ use crate::{ hal_api::HalApi, id::{BufferId, TypedId}, resource::{Buffer, Resource}, + snatch::SnatchGuard, storage::Storage, track::{ invalid_resource_state, skip_barrier, ResourceMetadata, ResourceMetadataProvider, @@ -387,10 +388,13 @@ impl BufferTracker { } /// Drains all currently pending transitions. - pub fn drain_transitions(&mut self) -> impl Iterator> { + pub fn drain_transitions<'a, 'b: 'a>( + &'b mut self, + snatch_guard: &'a SnatchGuard<'a>, + ) -> impl Iterator> { let buffer_barriers = self.temp.drain(..).map(|pending| { let buf = unsafe { self.metadata.get_resource_unchecked(pending.id as _) }; - pending.into_hal(buf) + pending.into_hal(buf, snatch_guard) }); buffer_barriers } diff --git a/wgpu-core/src/track/mod.rs b/wgpu-core/src/track/mod.rs index c2d5dc3c20..0f0b22b004 100644 --- a/wgpu-core/src/track/mod.rs +++ b/wgpu-core/src/track/mod.rs @@ -106,6 +106,7 @@ use crate::{ hal_api::HalApi, id::{self, TypedId}, pipeline, resource, + snatch::SnatchGuard, storage::Storage, }; @@ -138,8 +139,9 @@ impl PendingTransition { pub fn into_hal<'a, A: HalApi>( self, buf: &'a resource::Buffer, + snatch_guard: &'a SnatchGuard<'a>, ) -> hal::BufferBarrier<'a, A> { - let buffer = buf.raw.as_ref().expect("Buffer is destroyed"); + let buffer = buf.raw.get(snatch_guard).expect("Buffer is destroyed"); hal::BufferBarrier { buffer, usage: self.usage,