Allow obtaining custom implementation from wgpu api types (#7541)

This commit is contained in:
sagudev
2025-04-18 22:58:49 +02:00
committed by GitHub
parent a9a3ea3a41
commit 6666d528b2
30 changed files with 262 additions and 17 deletions

View File

@@ -3,9 +3,9 @@ use std::pin::Pin;
use std::sync::Arc;
use wgpu::custom::{
AdapterInterface, DeviceInterface, DispatchAdapter, DispatchDevice, DispatchQueue,
DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface, RequestAdapterFuture,
ShaderModuleInterface,
AdapterInterface, ComputePipelineInterface, DeviceInterface, DispatchAdapter, DispatchDevice,
DispatchQueue, DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface,
RequestAdapterFuture, ShaderModuleInterface,
};
#[derive(Debug, Clone)]
@@ -163,9 +163,10 @@ impl DeviceInterface for CustomDevice {
fn create_compute_pipeline(
&self,
_desc: &wgpu::ComputePipelineDescriptor<'_>,
desc: &wgpu::ComputePipelineDescriptor<'_>,
) -> wgpu::custom::DispatchComputePipeline {
unimplemented!()
let module = desc.module.as_custom::<CustomShaderModule>().unwrap();
wgpu::custom::DispatchComputePipeline::custom(CustomComputePipeline(module.0.clone()))
}
unsafe fn create_pipeline_cache(
@@ -265,7 +266,7 @@ impl DeviceInterface for CustomDevice {
}
#[derive(Debug)]
struct CustomShaderModule(Counter);
pub struct CustomShaderModule(pub Counter);
impl ShaderModuleInterface for CustomShaderModule {
fn get_compilation_info(&self) -> Pin<Box<dyn wgpu::custom::ShaderCompilationInfoFuture>> {
@@ -346,3 +347,12 @@ impl QueueInterface for CustomQueue {
unimplemented!()
}
}
#[derive(Debug)]
pub struct CustomComputePipeline(pub Counter);
impl ComputePipelineInterface for CustomComputePipeline {
fn get_bind_group_layout(&self, _index: u32) -> wgpu::custom::DispatchBindGroupLayout {
unimplemented!()
}
}

View File

@@ -1,6 +1,6 @@
use std::marker::PhantomData;
use custom::Counter;
use custom::{Counter, CustomShaderModule};
use wgpu::{DeviceDescriptor, RequestAdapterOptions};
mod custom;
@@ -31,12 +31,26 @@ async fn main() {
.unwrap();
assert_eq!(counter.count(), 5);
let _module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("shader"),
source: wgpu::ShaderSource::Dummy(PhantomData),
});
let custom_module = module.as_custom::<CustomShaderModule>().unwrap();
assert_eq!(custom_module.0.count(), 6);
let _module_clone = module.clone();
assert_eq!(counter.count(), 6);
let _pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &module,
entry_point: None,
compilation_options: Default::default(),
cache: None,
});
assert_eq!(counter.count(), 7);
}
assert_eq!(counter.count(), 1);
}