From 98426329a40e2482770238e9bad78dacff3db898 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Thu, 5 Sep 2024 14:30:41 +0200 Subject: [PATCH] [wgpu-core] introduce `Registry` `.strict_get()` & `.strict_unregister()` and use them for adapters This works because we never assign errors to adapters (they are never invalid). --- deno_webgpu/lib.rs | 6 +-- player/src/bin/play.rs | 2 +- player/tests/test.rs | 4 +- wgpu-core/src/device/global.rs | 13 ++---- wgpu-core/src/instance.rs | 80 +++++++++------------------------- wgpu-core/src/registry.rs | 13 ++++++ wgpu-core/src/resource.rs | 7 +-- wgpu-core/src/storage.rs | 29 ++++++++++++ wgpu/src/backend/wgpu_core.rs | 32 +++----------- 9 files changed, 81 insertions(+), 105 deletions(-) diff --git a/deno_webgpu/lib.rs b/deno_webgpu/lib.rs index c2dfb240f..c64cb4a06 100644 --- a/deno_webgpu/lib.rs +++ b/deno_webgpu/lib.rs @@ -414,9 +414,9 @@ pub fn op_webgpu_request_adapter( }) } }; - let adapter_features = instance.adapter_features(adapter)?; + let adapter_features = instance.adapter_features(adapter); let features = deserialize_features(&adapter_features); - let adapter_limits = instance.adapter_limits(adapter)?; + let adapter_limits = instance.adapter_limits(adapter); let instance = instance.clone(); @@ -705,7 +705,7 @@ pub fn op_webgpu_request_adapter_info( let adapter = adapter_resource.1; let instance = state.borrow::(); - let info = instance.adapter_get_info(adapter)?; + let info = instance.adapter_get_info(adapter); adapter_resource.close(); Ok(GPUAdapterInfo { diff --git a/player/src/bin/play.rs b/player/src/bin/play.rs index 4726fe63a..f0198f130 100644 --- a/player/src/bin/play.rs +++ b/player/src/bin/play.rs @@ -78,7 +78,7 @@ fn main() { ) .expect("Unable to find an adapter for selected backend"); - let info = global.adapter_get_info(adapter).unwrap(); + let info = global.adapter_get_info(adapter); log::info!("Picked '{}'", info.name); let device_id = wgc::id::Id::zip(1, 0, backend); let queue_id = wgc::id::Id::zip(1, 0, backend); diff --git a/player/tests/test.rs b/player/tests/test.rs index 367a7494a..653a6fc9b 100644 --- a/player/tests/test.rs +++ b/player/tests/test.rs @@ -244,8 +244,8 @@ impl Corpus { }; println!("\tBackend {:?}", backend); - let supported_features = global.adapter_features(adapter).unwrap(); - let downlevel_caps = global.adapter_downlevel_capabilities(adapter).unwrap(); + let supported_features = global.adapter_features(adapter); + let downlevel_caps = global.adapter_downlevel_capabilities(adapter); let test = Test::load(dir.join(test_path), adapter.backend()); if !supported_features.contains(test.features) { diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index ba5c99258..c22010510 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -43,10 +43,7 @@ impl Global { let hub = &self.hub; let surface_guard = self.surfaces.read(); - let adapter_guard = hub.adapters.read(); - let adapter = adapter_guard - .get(adapter_id) - .map_err(|_| instance::IsSurfaceSupportedError::InvalidAdapter)?; + let adapter = hub.adapters.strict_get(adapter_id); let surface = surface_guard .get(surface_id) .map_err(|_| instance::IsSurfaceSupportedError::InvalidSurface)?; @@ -87,15 +84,13 @@ impl Global { let hub = &self.hub; let surface_guard = self.surfaces.read(); - let adapter_guard = hub.adapters.read(); - let adapter = adapter_guard - .get(adapter_id) - .map_err(|_| instance::GetSurfaceSupportError::InvalidAdapter)?; + let adapter = hub.adapters.strict_get(adapter_id); + let surface = surface_guard .get(surface_id) .map_err(|_| instance::GetSurfaceSupportError::InvalidSurface)?; - get_supported_callback(adapter, surface) + get_supported_callback(&adapter, surface) } pub fn device_features(&self, device_id: DeviceId) -> Result { diff --git a/wgpu-core/src/instance.rs b/wgpu-core/src/instance.rs index ca19d3c53..6f9292672 100644 --- a/wgpu-core/src/instance.rs +++ b/wgpu-core/src/instance.rs @@ -349,8 +349,6 @@ crate::impl_storage_item!(Adapter); #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum IsSurfaceSupportedError { - #[error("Invalid adapter")] - InvalidAdapter, #[error("Invalid surface")] InvalidSurface, } @@ -358,8 +356,6 @@ pub enum IsSurfaceSupportedError { #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum GetSurfaceSupportError { - #[error("Invalid adapter")] - InvalidAdapter, #[error("Invalid surface")] InvalidSurface, #[error("Surface is not supported by the adapter")] @@ -373,8 +369,6 @@ pub enum GetSurfaceSupportError { pub enum RequestDeviceError { #[error(transparent)] Device(#[from] DeviceError), - #[error("Parent adapter is invalid")] - InvalidAdapter, #[error(transparent)] LimitsExceeded(#[from] FailedLimit), #[error("Device has no queue supporting graphics")] @@ -403,10 +397,6 @@ impl AdapterInputs<'_, M> { } } -#[derive(Clone, Debug, Error)] -#[error("Adapter is invalid")] -pub struct InvalidAdapter; - #[derive(Clone, Debug, Error)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] @@ -869,73 +859,51 @@ impl Global { id } - pub fn adapter_get_info( - &self, - adapter_id: AdapterId, - ) -> Result { - self.hub - .adapters - .get(adapter_id) - .map(|adapter| adapter.raw.info.clone()) - .map_err(|_| InvalidAdapter) + pub fn adapter_get_info(&self, adapter_id: AdapterId) -> wgt::AdapterInfo { + let adapter = self.hub.adapters.strict_get(adapter_id); + adapter.raw.info.clone() } pub fn adapter_get_texture_format_features( &self, adapter_id: AdapterId, format: wgt::TextureFormat, - ) -> Result { - self.hub - .adapters - .get(adapter_id) - .map(|adapter| adapter.get_texture_format_features(format)) - .map_err(|_| InvalidAdapter) + ) -> wgt::TextureFormatFeatures { + let adapter = self.hub.adapters.strict_get(adapter_id); + adapter.get_texture_format_features(format) } - pub fn adapter_features(&self, adapter_id: AdapterId) -> Result { - self.hub - .adapters - .get(adapter_id) - .map(|adapter| adapter.raw.features) - .map_err(|_| InvalidAdapter) + pub fn adapter_features(&self, adapter_id: AdapterId) -> wgt::Features { + let adapter = self.hub.adapters.strict_get(adapter_id); + adapter.raw.features } - pub fn adapter_limits(&self, adapter_id: AdapterId) -> Result { - self.hub - .adapters - .get(adapter_id) - .map(|adapter| adapter.raw.capabilities.limits.clone()) - .map_err(|_| InvalidAdapter) + pub fn adapter_limits(&self, adapter_id: AdapterId) -> wgt::Limits { + let adapter = self.hub.adapters.strict_get(adapter_id); + adapter.raw.capabilities.limits.clone() } pub fn adapter_downlevel_capabilities( &self, adapter_id: AdapterId, - ) -> Result { - self.hub - .adapters - .get(adapter_id) - .map(|adapter| adapter.raw.capabilities.downlevel.clone()) - .map_err(|_| InvalidAdapter) + ) -> wgt::DownlevelCapabilities { + let adapter = self.hub.adapters.strict_get(adapter_id); + adapter.raw.capabilities.downlevel.clone() } pub fn adapter_get_presentation_timestamp( &self, adapter_id: AdapterId, - ) -> Result { - let hub = &self.hub; - - let adapter = hub.adapters.get(adapter_id).map_err(|_| InvalidAdapter)?; - - Ok(unsafe { adapter.raw.adapter.get_presentation_timestamp() }) + ) -> wgt::PresentationTimestamp { + let adapter = self.hub.adapters.strict_get(adapter_id); + unsafe { adapter.raw.adapter.get_presentation_timestamp() } } pub fn adapter_drop(&self, adapter_id: AdapterId) { profiling::scope!("Adapter::drop"); api_log!("Adapter::drop {adapter_id:?}"); - let hub = &self.hub; - hub.adapters.unregister(adapter_id); + self.hub.adapters.strict_unregister(adapter_id); } } @@ -956,10 +924,7 @@ impl Global { let queue_fid = self.hub.queues.prepare(backend, queue_id_in); let error = 'error: { - let adapter = match self.hub.adapters.get(adapter_id) { - Ok(adapter) => adapter, - Err(_) => break 'error RequestDeviceError::InvalidAdapter, - }; + let adapter = self.hub.adapters.strict_get(adapter_id); let (device, queue) = match adapter.create_device_and_queue(desc, self.instance.flags, trace_path) { Ok((device, queue)) => (device, queue), @@ -1000,10 +965,7 @@ impl Global { let queues_fid = self.hub.queues.prepare(backend, queue_id_in); let error = 'error: { - let adapter = match self.hub.adapters.get(adapter_id) { - Ok(adapter) => adapter, - Err(_) => break 'error RequestDeviceError::InvalidAdapter, - }; + let adapter = self.hub.adapters.strict_get(adapter_id); let (device, queue) = match adapter.create_device_and_queue_from_hal( hal_device, desc, diff --git a/wgpu-core/src/registry.rs b/wgpu-core/src/registry.rs index 82617d985..9e76f2fc3 100644 --- a/wgpu-core/src/registry.rs +++ b/wgpu-core/src/registry.rs @@ -120,6 +120,15 @@ impl Registry { //Returning None is legal if it's an error ID value } + pub(crate) fn strict_unregister(&self, id: Id) -> T { + let value = self.storage.write().strict_remove(id); + // This needs to happen *after* removing it from the storage, to maintain the + // invariant that `self.identity` only contains ids which are actually available + // See https://github.com/gfx-rs/wgpu/issues/5372 + self.identity.free(id); + //Returning None is legal if it's an error ID + value + } pub(crate) fn generate_report(&self) -> RegistryReport { let storage = self.storage.read(); @@ -143,6 +152,10 @@ impl Registry { pub(crate) fn get(&self, id: Id) -> Result { self.read().get_owned(id) } + + pub(crate) fn strict_get(&self, id: Id) -> T { + self.read().strict_get(id) + } } #[cfg(test)] diff --git a/wgpu-core/src/resource.rs b/wgpu-core/src/resource.rs index dc0792cb3..159ac8c89 100644 --- a/wgpu-core/src/resource.rs +++ b/wgpu-core/src/resource.rs @@ -1272,11 +1272,8 @@ impl Global { profiling::scope!("Adapter::as_hal"); let hub = &self.hub; - let adapter = hub.adapters.get(id).ok(); - let hal_adapter = adapter - .as_ref() - .map(|adapter| &adapter.raw.adapter) - .and_then(|adapter| adapter.as_any().downcast_ref()); + let adapter = hub.adapters.strict_get(id); + let hal_adapter = adapter.raw.adapter.as_any().downcast_ref(); hal_adapter_callback(hal_adapter) } diff --git a/wgpu-core/src/storage.rs b/wgpu-core/src/storage.rs index cdcf0ea16..e7ed8ff55 100644 --- a/wgpu-core/src/storage.rs +++ b/wgpu-core/src/storage.rs @@ -157,6 +157,18 @@ where } } + pub(crate) fn strict_remove(&mut self, id: Id) -> T { + let (index, epoch, _) = id.unzip(); + match std::mem::replace(&mut self.map[index as usize], Element::Vacant) { + Element::Occupied(value, storage_epoch) => { + assert_eq!(epoch, storage_epoch); + value + } + Element::Error(_) => unreachable!(), + Element::Vacant => panic!("Cannot remove a vacant resource"), + } + } + pub(crate) fn iter(&self, backend: Backend) -> impl Iterator, &T)> { self.map .iter() @@ -183,4 +195,21 @@ where pub(crate) fn get_owned(&self, id: Id) -> Result { Ok(self.get(id)?.clone()) } + + /// Get an owned reference to an item. + /// Panics if there is an epoch mismatch, the entry is empty or in error. + pub(crate) fn strict_get(&self, id: Id) -> T { + let (index, epoch, _) = id.unzip(); + let (result, storage_epoch) = match self.map.get(index as usize) { + Some(&Element::Occupied(ref v, epoch)) => (v.clone(), epoch), + None | Some(&Element::Vacant) => panic!("{}[{:?}] does not exist", self.kind, id), + Some(&Element::Error(_)) => unreachable!(), + }; + assert_eq!( + epoch, storage_epoch, + "{}[{:?}] is no longer alive", + self.kind, id + ); + result + } } diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 9b2597cb4..ad8ccd645 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -651,34 +651,22 @@ impl crate::Context for ContextWgpuCore { } fn adapter_features(&self, adapter_data: &Self::AdapterData) -> Features { - match self.0.adapter_features(*adapter_data) { - Ok(features) => features, - Err(err) => self.handle_error_fatal(err, "Adapter::features"), - } + self.0.adapter_features(*adapter_data) } fn adapter_limits(&self, adapter_data: &Self::AdapterData) -> Limits { - match self.0.adapter_limits(*adapter_data) { - Ok(limits) => limits, - Err(err) => self.handle_error_fatal(err, "Adapter::limits"), - } + self.0.adapter_limits(*adapter_data) } fn adapter_downlevel_capabilities( &self, adapter_data: &Self::AdapterData, ) -> DownlevelCapabilities { - match self.0.adapter_downlevel_capabilities(*adapter_data) { - Ok(downlevel) => downlevel, - Err(err) => self.handle_error_fatal(err, "Adapter::downlevel_properties"), - } + self.0.adapter_downlevel_capabilities(*adapter_data) } fn adapter_get_info(&self, adapter_data: &Self::AdapterData) -> AdapterInfo { - match self.0.adapter_get_info(*adapter_data) { - Ok(info) => info, - Err(err) => self.handle_error_fatal(err, "Adapter::get_info"), - } + self.0.adapter_get_info(*adapter_data) } fn adapter_get_texture_format_features( @@ -686,23 +674,15 @@ impl crate::Context for ContextWgpuCore { adapter_data: &Self::AdapterData, format: wgt::TextureFormat, ) -> wgt::TextureFormatFeatures { - match self - .0 + self.0 .adapter_get_texture_format_features(*adapter_data, format) - { - Ok(info) => info, - Err(err) => self.handle_error_fatal(err, "Adapter::get_texture_format_features"), - } } fn adapter_get_presentation_timestamp( &self, adapter_data: &Self::AdapterData, ) -> wgt::PresentationTimestamp { - match self.0.adapter_get_presentation_timestamp(*adapter_data) { - Ok(timestamp) => timestamp, - Err(err) => self.handle_error_fatal(err, "Adapter::correlate_presentation_timestamp"), - } + self.0.adapter_get_presentation_timestamp(*adapter_data) } fn surface_get_capabilities(