mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-01-09 14:48:08 -05:00
Allow obtaining custom implementation from wgpu api types (#7541)
This commit is contained in:
@@ -3,9 +3,9 @@ use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use wgpu::custom::{
|
||||
AdapterInterface, DeviceInterface, DispatchAdapter, DispatchDevice, DispatchQueue,
|
||||
DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface, RequestAdapterFuture,
|
||||
ShaderModuleInterface,
|
||||
AdapterInterface, ComputePipelineInterface, DeviceInterface, DispatchAdapter, DispatchDevice,
|
||||
DispatchQueue, DispatchShaderModule, DispatchSurface, InstanceInterface, QueueInterface,
|
||||
RequestAdapterFuture, ShaderModuleInterface,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -163,9 +163,10 @@ impl DeviceInterface for CustomDevice {
|
||||
|
||||
fn create_compute_pipeline(
|
||||
&self,
|
||||
_desc: &wgpu::ComputePipelineDescriptor<'_>,
|
||||
desc: &wgpu::ComputePipelineDescriptor<'_>,
|
||||
) -> wgpu::custom::DispatchComputePipeline {
|
||||
unimplemented!()
|
||||
let module = desc.module.as_custom::<CustomShaderModule>().unwrap();
|
||||
wgpu::custom::DispatchComputePipeline::custom(CustomComputePipeline(module.0.clone()))
|
||||
}
|
||||
|
||||
unsafe fn create_pipeline_cache(
|
||||
@@ -265,7 +266,7 @@ impl DeviceInterface for CustomDevice {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CustomShaderModule(Counter);
|
||||
pub struct CustomShaderModule(pub Counter);
|
||||
|
||||
impl ShaderModuleInterface for CustomShaderModule {
|
||||
fn get_compilation_info(&self) -> Pin<Box<dyn wgpu::custom::ShaderCompilationInfoFuture>> {
|
||||
@@ -346,3 +347,12 @@ impl QueueInterface for CustomQueue {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CustomComputePipeline(pub Counter);
|
||||
|
||||
impl ComputePipelineInterface for CustomComputePipeline {
|
||||
fn get_bind_group_layout(&self, _index: u32) -> wgpu::custom::DispatchBindGroupLayout {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use custom::Counter;
|
||||
use custom::{Counter, CustomShaderModule};
|
||||
use wgpu::{DeviceDescriptor, RequestAdapterOptions};
|
||||
|
||||
mod custom;
|
||||
@@ -31,12 +31,26 @@ async fn main() {
|
||||
.unwrap();
|
||||
assert_eq!(counter.count(), 5);
|
||||
|
||||
let _module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: Some("shader"),
|
||||
source: wgpu::ShaderSource::Dummy(PhantomData),
|
||||
});
|
||||
|
||||
let custom_module = module.as_custom::<CustomShaderModule>().unwrap();
|
||||
assert_eq!(custom_module.0.count(), 6);
|
||||
let _module_clone = module.clone();
|
||||
assert_eq!(counter.count(), 6);
|
||||
|
||||
let _pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: None,
|
||||
module: &module,
|
||||
entry_point: None,
|
||||
compilation_options: Default::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
assert_eq!(counter.count(), 7);
|
||||
}
|
||||
assert_eq!(counter.count(), 1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user