Initial precompiled shaders implementation (#7834)

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
This commit is contained in:
Magnus
2025-08-20 15:20:59 -05:00
committed by GitHub
parent eb9b2e9c9b
commit 17a17f716a
26 changed files with 383 additions and 315 deletions

View File

@@ -53,6 +53,27 @@ We have merged the acceleration structure feature into the `RayQuery` feature. T
By @Vecvec in [#7913](https://github.com/gfx-rs/wgpu/pull/7913).
#### New `EXPERIMENTAL_PRECOMPILED_SHADERS` API
We have added `Features::EXPERIMENTAL_PRECOMPILED_SHADERS`, replacing existing passthrough types with a unified `CreateShaderModuleDescriptorPassthrough` which allows passing multiple shader codes for different backends. By @SupaMaggie70Incorporated in [#7834](https://github.com/gfx-rs/wgpu/pull/7834)
Difference for SPIR-V passthrough:
```diff
- device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV(
- wgpu::ShaderModuleDescriptorSpirV {
- label: None,
- source: spirv_code,
- },
- ))
+ device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
+ entry_point: "main".into(),
+ label: None,
+ spirv: Some(spirv_code),
+ ..Default::default()
})
```
This allows using precompiled shaders without manually checking which backend's code to pass, for example if you have shaders precompiled for both DXIL and SPIR-V.
### New Features
#### General

1
Cargo.lock generated
View File

@@ -3191,6 +3191,7 @@ checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
name = "player"
version = "26.0.0"
dependencies = [
"bytemuck",
"env_logger",
"log",
"raw-window-handle 0.6.2",

View File

@@ -419,10 +419,6 @@ pub enum GPUFeatureName {
VertexWritableStorage,
#[webidl(rename = "clear-texture")]
ClearTexture,
#[webidl(rename = "msl-shader-passthrough")]
MslShaderPassthrough,
#[webidl(rename = "spirv-shader-passthrough")]
SpirvShaderPassthrough,
#[webidl(rename = "multiview")]
Multiview,
#[webidl(rename = "vertex-attribute-64-bit")]
@@ -435,6 +431,8 @@ pub enum GPUFeatureName {
ShaderPrimitiveIndex,
#[webidl(rename = "shader-early-depth-test")]
ShaderEarlyDepthTest,
#[webidl(rename = "passthrough-shaders")]
PassthroughShaders,
}
pub fn feature_names_to_features(names: Vec<GPUFeatureName>) -> wgpu_types::Features {
@@ -482,14 +480,13 @@ pub fn feature_names_to_features(names: Vec<GPUFeatureName>) -> wgpu_types::Feat
GPUFeatureName::ConservativeRasterization => Features::CONSERVATIVE_RASTERIZATION,
GPUFeatureName::VertexWritableStorage => Features::VERTEX_WRITABLE_STORAGE,
GPUFeatureName::ClearTexture => Features::CLEAR_TEXTURE,
GPUFeatureName::MslShaderPassthrough => Features::MSL_SHADER_PASSTHROUGH,
GPUFeatureName::SpirvShaderPassthrough => Features::SPIRV_SHADER_PASSTHROUGH,
GPUFeatureName::Multiview => Features::MULTIVIEW,
GPUFeatureName::VertexAttribute64Bit => Features::VERTEX_ATTRIBUTE_64BIT,
GPUFeatureName::ShaderF64 => Features::SHADER_F64,
GPUFeatureName::ShaderI16 => Features::SHADER_I16,
GPUFeatureName::ShaderPrimitiveIndex => Features::SHADER_PRIMITIVE_INDEX,
GPUFeatureName::ShaderEarlyDepthTest => Features::SHADER_EARLY_DEPTH_TEST,
GPUFeatureName::PassthroughShaders => Features::EXPERIMENTAL_PASSTHROUGH_SHADERS,
};
features.set(feature, true);
}
@@ -626,9 +623,6 @@ pub fn features_to_feature_names(features: wgpu_types::Features) -> HashSet<GPUF
if features.contains(wgpu_types::Features::CLEAR_TEXTURE) {
return_features.insert(ClearTexture);
}
if features.contains(wgpu_types::Features::SPIRV_SHADER_PASSTHROUGH) {
return_features.insert(SpirvShaderPassthrough);
}
if features.contains(wgpu_types::Features::MULTIVIEW) {
return_features.insert(Multiview);
}
@@ -648,6 +642,9 @@ pub fn features_to_feature_names(features: wgpu_types::Features) -> HashSet<GPUF
if features.contains(wgpu_types::Features::SHADER_EARLY_DEPTH_TEST) {
return_features.insert(ShaderEarlyDepthTest);
}
if features.contains(wgpu_types::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS) {
return_features.insert(PassthroughShaders);
}
return_features
}

View File

@@ -24,12 +24,12 @@ fn compile_glsl(
let output = cmd.wait_with_output().expect("Error waiting for glslc");
assert!(output.status.success());
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV(
wgpu::ShaderModuleDescriptorSpirV {
label: None,
source: wgpu::util::make_spirv_raw(&output.stdout),
},
))
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
entry_point: "main".into(),
label: None,
spirv: Some(wgpu::util::make_spirv_raw(&output.stdout)),
..Default::default()
})
}
}
@@ -119,7 +119,7 @@ impl crate::framework::Example for Example {
Default::default()
}
fn required_features() -> wgpu::Features {
wgpu::Features::EXPERIMENTAL_MESH_SHADER | wgpu::Features::SPIRV_SHADER_PASSTHROUGH
wgpu::Features::EXPERIMENTAL_MESH_SHADER | wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS
}
fn required_limits() -> wgpu::Limits {
wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()

View File

@@ -26,6 +26,7 @@ log.workspace = true
raw-window-handle.workspace = true
ron.workspace = true
winit = { workspace = true, optional = true }
bytemuck.workspace = true
# Non-Webassembly
#

View File

@@ -315,6 +315,84 @@ impl GlobalPlay for wgc::global::Global {
println!("shader compilation error:\n---{code}\n---\n{e}");
}
}
Action::CreateShaderModulePassthrough {
id,
data,
entry_point,
label,
num_workgroups,
runtime_checks,
} => {
let spirv = data.iter().find_map(|a| {
if a.ends_with(".spv") {
let data = fs::read(dir.join(a)).unwrap();
assert!(data.len() % 4 == 0);
Some(Cow::Owned(bytemuck::pod_collect_to_vec(&data)))
} else {
None
}
});
let dxil = data.iter().find_map(|a| {
if a.ends_with(".dxil") {
let vec = std::fs::read(dir.join(a)).unwrap();
Some(Cow::Owned(vec))
} else {
None
}
});
let hlsl = data.iter().find_map(|a| {
if a.ends_with(".hlsl") {
let code = fs::read_to_string(dir.join(a)).unwrap();
Some(Cow::Owned(code))
} else {
None
}
});
let msl = data.iter().find_map(|a| {
if a.ends_with(".msl") {
let code = fs::read_to_string(dir.join(a)).unwrap();
Some(Cow::Owned(code))
} else {
None
}
});
let glsl = data.iter().find_map(|a| {
if a.ends_with(".glsl") {
let code = fs::read_to_string(dir.join(a)).unwrap();
Some(Cow::Owned(code))
} else {
None
}
});
let wgsl = data.iter().find_map(|a| {
if a.ends_with(".wgsl") {
let code = fs::read_to_string(dir.join(a)).unwrap();
Some(Cow::Owned(code))
} else {
None
}
});
let desc = wgt::CreateShaderModuleDescriptorPassthrough {
entry_point,
label,
num_workgroups,
runtime_checks,
spirv,
dxil,
hlsl,
msl,
glsl,
wgsl,
};
let (_, error) = unsafe {
self.device_create_shader_module_passthrough(device, &desc, Some(id))
};
if let Some(e) = error {
println!("shader compilation error: {e}");
}
}
Action::DestroyShaderModule(id) => {
self.shader_module_drop(id);
}

View File

@@ -41,12 +41,12 @@ fn compile_glsl(
let output = cmd.wait_with_output().expect("Error waiting for glslc");
assert!(output.status.success());
unsafe {
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV(
wgpu::ShaderModuleDescriptorSpirV {
label: None,
source: wgpu::util::make_spirv_raw(&output.stdout),
},
))
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
entry_point: "main".into(),
label: None,
spirv: Some(wgpu::util::make_spirv_raw(&output.stdout)),
..Default::default()
})
}
}
@@ -267,7 +267,7 @@ fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration {
.test_features_limits()
.features(
wgpu::Features::EXPERIMENTAL_MESH_SHADER
| wgpu::Features::SPIRV_SHADER_PASSTHROUGH
| wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS
| match draw_type {
DrawType::Standard | DrawType::Indirect => wgpu::Features::empty(),
DrawType::MultiIndirect => wgpu::Features::MULTI_DRAW_INDIRECT,

View File

@@ -1094,36 +1094,27 @@ impl Global {
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
let data = trace.make_binary(desc.trace_binary_ext(), desc.trace_data());
trace.add(trace::Action::CreateShaderModule {
let mut file_names = Vec::new();
for (data, ext) in [
(desc.spirv.as_ref().map(|a| bytemuck::cast_slice(a)), "spv"),
(desc.dxil.as_deref(), "dxil"),
(desc.hlsl.as_ref().map(|a| a.as_bytes()), "hlsl"),
(desc.msl.as_ref().map(|a| a.as_bytes()), "msl"),
(desc.glsl.as_ref().map(|a| a.as_bytes()), "glsl"),
(desc.wgsl.as_ref().map(|a| a.as_bytes()), "wgsl"),
] {
if let Some(data) = data {
file_names.push(trace.make_binary(ext, data));
}
}
trace.add(trace::Action::CreateShaderModulePassthrough {
id: fid.id(),
desc: match desc {
pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
},
data,
data: file_names,
entry_point: desc.entry_point.clone(),
label: desc.label.clone(),
num_workgroups: desc.num_workgroups,
runtime_checks: desc.runtime_checks,
});
};
@@ -1138,7 +1129,7 @@ impl Global {
return (id, None);
};
let id = fid.assign(Fallible::Invalid(Arc::new(desc.label().to_string())));
let id = fid.assign(Fallible::Invalid(Arc::new(desc.label.to_string())));
(id, Some(error))
}

View File

@@ -2125,39 +2125,59 @@ impl Device {
descriptor: &pipeline::ShaderModuleDescriptorPassthrough<'a>,
) -> Result<Arc<pipeline::ShaderModule>, pipeline::CreateShaderModuleError> {
self.check_is_valid()?;
let hal_shader = match descriptor {
pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => {
self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?;
hal::ShaderInput::SpirV(&inner.source)
}
pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => {
self.require_features(wgt::Features::MSL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Msl {
shader: inner.source.to_string(),
entry_point: inner.entry_point.to_string(),
num_workgroups: inner.num_workgroups,
self.require_features(wgt::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS)?;
// TODO: when we get to use if-let chains, this will be a little nicer!
log::info!("Backend: {}", self.backend());
let hal_shader = match self.backend() {
wgt::Backend::Vulkan => hal::ShaderInput::SpirV(
descriptor
.spirv
.as_ref()
.ok_or(pipeline::CreateShaderModuleError::NotCompiledForBackend)?,
),
wgt::Backend::Dx12 => {
if let Some(dxil) = &descriptor.dxil {
hal::ShaderInput::Dxil {
shader: dxil,
entry_point: descriptor.entry_point.clone(),
num_workgroups: descriptor.num_workgroups,
}
} else if let Some(hlsl) = &descriptor.hlsl {
hal::ShaderInput::Hlsl {
shader: hlsl,
entry_point: descriptor.entry_point.clone(),
num_workgroups: descriptor.num_workgroups,
}
} else {
return Err(pipeline::CreateShaderModuleError::NotCompiledForBackend);
}
}
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Dxil {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
}
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Hlsl {
shader: inner.source,
entry_point: inner.entry_point.clone(),
num_workgroups: inner.num_workgroups,
}
wgt::Backend::Metal => hal::ShaderInput::Msl {
shader: descriptor
.msl
.as_ref()
.ok_or(pipeline::CreateShaderModuleError::NotCompiledForBackend)?,
entry_point: descriptor.entry_point.clone(),
num_workgroups: descriptor.num_workgroups,
},
wgt::Backend::Gl => hal::ShaderInput::Glsl {
shader: descriptor
.glsl
.as_ref()
.ok_or(pipeline::CreateShaderModuleError::NotCompiledForBackend)?,
entry_point: descriptor.entry_point.clone(),
num_workgroups: descriptor.num_workgroups,
},
wgt::Backend::Noop => {
return Err(pipeline::CreateShaderModuleError::NotCompiledForBackend)
}
wgt::Backend::BrowserWebGpu => unreachable!(),
};
let hal_desc = hal::ShaderModuleDescriptor {
label: descriptor.label().to_hal(self.instance_flags),
label: descriptor.label.to_hal(self.instance_flags),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
};
@@ -2180,7 +2200,7 @@ impl Device {
raw: ManuallyDrop::new(raw),
device: self.clone(),
interface: None,
label: descriptor.label().to_string(),
label: descriptor.label.to_string(),
};
Ok(Arc::new(module))

View File

@@ -93,6 +93,15 @@ pub enum Action<'a> {
desc: crate::pipeline::ShaderModuleDescriptor<'a>,
data: FileName,
},
CreateShaderModulePassthrough {
id: id::ShaderModuleId,
data: Vec<FileName>,
entry_point: String,
label: crate::Label<'a>,
num_workgroups: (u32, u32, u32),
runtime_checks: wgt::ShaderRuntimeChecks,
},
DestroyShaderModule(id::ShaderModuleId),
CreateComputePipeline {
id: id::ComputePipelineId,

View File

@@ -130,6 +130,8 @@ pub enum CreateShaderModuleError {
group: u32,
limit: u32,
},
#[error("Generic shader passthrough does not contain any code compatible with this backend.")]
NotCompiledForBackend,
}
impl WebGpuError for CreateShaderModuleError {
@@ -147,6 +149,7 @@ impl WebGpuError for CreateShaderModuleError {
Self::ParsingGlsl(..) => return ErrorType::Validation,
#[cfg(feature = "spirv")]
Self::ParsingSpirV(..) => return ErrorType::Validation,
Self::NotCompiledForBackend => return ErrorType::Validation,
};
e.webgpu_error_type()
}

View File

@@ -363,8 +363,8 @@ impl super::Adapter {
| wgt::Features::TEXTURE_FORMAT_NV12
| wgt::Features::FLOAT32_FILTERABLE
| wgt::Features::TEXTURE_ATOMIC
| wgt::Features::EXTERNAL_TEXTURE
| wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH;
| wgt::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS
| wgt::Features::EXTERNAL_TEXTURE;
//TODO: in order to expose this, we need to run a compute shader
// that extract the necessary statistics out of the D3D12 result.

View File

@@ -1777,12 +1777,6 @@ impl crate::Device for super::Device {
raw_name,
runtime_checks: desc.runtime_checks,
}),
crate::ShaderInput::SpirV(_) => {
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil {
shader,
entry_point,
@@ -1809,6 +1803,11 @@ impl crate::Device for super::Device {
raw_name,
runtime_checks: desc.runtime_checks,
}),
crate::ShaderInput::SpirV(_)
| crate::ShaderInput::Msl { .. }
| crate::ShaderInput::Glsl { .. } => {
unreachable!()
}
}
}
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {

View File

@@ -222,8 +222,8 @@ impl super::Device {
};
let (module, info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
&stage.module.source.module,
&stage.module.source.info,
Some((naga_stage, stage.entry_point)),
stage.constants,
)
@@ -463,7 +463,7 @@ impl super::Device {
for (stage_idx, stage_items) in push_constant_items.into_iter().enumerate() {
for item in stage_items {
let naga_module = &shaders[stage_idx].1.module.naga.module;
let naga_module = &shaders[stage_idx].1.module.source.module;
let type_inner = &naga_module.types[item.ty].inner;
let location = unsafe { gl.get_uniform_location(program, &item.access_path) };
@@ -1334,16 +1334,15 @@ impl crate::Device for super::Device {
self.counters.shader_modules.add(1);
Ok(super::ShaderModule {
naga: match shader {
crate::ShaderInput::SpirV(_) => {
panic!("`Features::SPIRV_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::Msl { .. } => {
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
}
source: match shader {
crate::ShaderInput::Naga(naga) => naga,
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
// The backend doesn't yet expose this feature so it should be fine
crate::ShaderInput::Glsl { .. } => unimplemented!(),
crate::ShaderInput::SpirV(_)
| crate::ShaderInput::Msl { .. }
| crate::ShaderInput::Dxil { .. }
| crate::ShaderInput::Hlsl { .. } => {
unreachable!()
}
},
label: desc.label.map(|str| str.to_string()),

View File

@@ -605,7 +605,7 @@ type ShaderId = u32;
#[derive(Debug)]
pub struct ShaderModule {
naga: crate::NagaShader,
source: crate::NagaShader,
label: Option<String>,
id: ShaderId,
}

View File

@@ -2219,7 +2219,7 @@ impl fmt::Debug for NagaShader {
pub enum ShaderInput<'a> {
Naga(NagaShader),
Msl {
shader: String,
shader: &'a str,
entry_point: String,
num_workgroups: (u32, u32, u32),
},
@@ -2234,6 +2234,11 @@ pub enum ShaderInput<'a> {
entry_point: String,
num_workgroups: (u32, u32, u32),
},
Glsl {
shader: &'a str,
entry_point: String,
num_workgroups: (u32, u32, u32),
},
}
pub struct ShaderModuleDescriptor<'a> {

View File

@@ -917,7 +917,6 @@ impl super::PrivateCapabilities {
use wgt::Features as F;
let mut features = F::empty()
| F::MSL_SHADER_PASSTHROUGH
| F::MAPPABLE_PRIMARY_BUFFERS
| F::VERTEX_WRITABLE_STORAGE
| F::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES
@@ -927,7 +926,8 @@ impl super::PrivateCapabilities {
| F::TEXTURE_FORMAT_16BIT_NORM
| F::SHADER_F16
| F::DEPTH32FLOAT_STENCIL8
| F::BGRA8UNORM_STORAGE;
| F::BGRA8UNORM_STORAGE
| F::EXPERIMENTAL_PASSTHROUGH_SHADERS;
features.set(F::FLOAT32_FILTERABLE, self.supports_float_filtering);
features.set(

View File

@@ -1017,7 +1017,7 @@ impl crate::Device for super::Device {
// Obtain the locked device from shared
let device = self.shared.device.lock();
let library = device
.new_library_with_source(&source, &options)
.new_library_with_source(source, &options)
.map_err(|e| crate::ShaderError::Compilation(format!("MSL: {e:?}")))?;
let function = library.get_function(&entry_point, None).map_err(|_| {
crate::ShaderError::Compilation(format!(
@@ -1035,12 +1035,10 @@ impl crate::Device for super::Device {
bounds_checks: desc.runtime_checks,
})
}
crate::ShaderInput::SpirV(_) => {
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled for this backend")
}
crate::ShaderInput::SpirV(_)
| crate::ShaderInput::Dxil { .. }
| crate::ShaderInput::Hlsl { .. }
| crate::ShaderInput::Glsl { .. } => unreachable!(),
}
}

View File

@@ -543,7 +543,6 @@ impl PhysicalDeviceFeatures {
) -> (wgt::Features, wgt::DownlevelFlags) {
use wgt::{DownlevelFlags as Df, Features as F};
let mut features = F::empty()
| F::SPIRV_SHADER_PASSTHROUGH
| F::MAPPABLE_PRIMARY_BUFFERS
| F::PUSH_CONSTANTS
| F::ADDRESS_MODE_CLAMP_TO_BORDER
@@ -555,7 +554,8 @@ impl PhysicalDeviceFeatures {
| F::CLEAR_TEXTURE
| F::PIPELINE_CACHE
| F::SHADER_EARLY_DEPTH_TEST
| F::TEXTURE_ATOMIC;
| F::TEXTURE_ATOMIC
| F::EXPERIMENTAL_PASSTHROUGH_SHADERS;
let mut dl_flags = Df::COMPUTE_SHADERS
| Df::BASE_VERTEX

View File

@@ -1940,13 +1940,11 @@ impl crate::Device for super::Device {
.map_err(|e| crate::ShaderError::Compilation(format!("{e}")))?,
)
}
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv),
crate::ShaderInput::SpirV(data) => Cow::Borrowed(data),
crate::ShaderInput::Msl { .. }
| crate::ShaderInput::Dxil { .. }
| crate::ShaderInput::Hlsl { .. }
| crate::ShaderInput::Glsl { .. } => unreachable!(),
};
let raw = self.create_shader_module_impl(&spv)?;

View File

@@ -926,29 +926,6 @@ bitflags_array! {
///
/// This is a native only feature.
const CLEAR_TEXTURE = 1 << 23;
/// Enables creating shader modules from Metal MSL computer shaders (unsafe).
///
/// Metal data is not parsed or interpreted in any way
///
/// Supported platforms:
/// - Metal
///
/// This is a native only feature.
const MSL_SHADER_PASSTHROUGH = 1 << 24;
/// Enables creating shader modules from SPIR-V binary data (unsafe).
///
/// SPIR-V data is not parsed or interpreted in any way; you can use
/// [`wgpu::make_spirv_raw!`] to check for alignment and magic number when converting from
/// raw bytes.
///
/// Supported platforms:
/// - Vulkan, in case shader's requested capabilities and extensions agree with
/// Vulkan implementation.
///
/// This is a native only feature.
///
/// [`wgpu::make_spirv_raw!`]: https://docs.rs/wgpu/latest/wgpu/macro.include_spirv_raw.html
const SPIRV_SHADER_PASSTHROUGH = 1 << 25;
/// Enables multiview render passes and `builtin(view_index)` in vertex shaders.
///
/// Supported platforms:
@@ -1243,15 +1220,23 @@ bitflags_array! {
/// [BlasTriangleGeometrySizeDescriptor::vertex_format]: super::BlasTriangleGeometrySizeDescriptor
const EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS = 1 << 51;
/// Enables creating shader modules from DirectX HLSL or DXIL shaders (unsafe)
/// Enables creating shaders from passthrough with reflection info (unsafe)
///
/// HLSL/DXIL data is not parsed or interpreted in any way
/// Allows using [`Device::create_shader_module_passthrough`].
/// Shader code isn't parsed or interpreted in any way. It is the user's
/// responsibility to ensure the code and reflection (if passed) are correct.
///
/// Supported platforms:
/// Supported platforms
/// - Vulkan
/// - DX12
/// - Metal
/// - WebGPU
///
/// This is a native only feature.
const HLSL_DXIL_SHADER_PASSTHROUGH = 1 << 52;
/// Ideally, in the future, all platforms will be supported. For more info, see
/// [this comment](https://github.com/gfx-rs/wgpu/issues/3103#issuecomment-2833058367).
///
/// [`Device::create_shader_module_passthrough`]: https://docs.rs/wgpu/latest/wgpu/struct.Device.html#method.create_shader_module_passthrough
const EXPERIMENTAL_PASSTHROUGH_SHADERS = 1 << 52;
}
/// Features that are not guaranteed to be supported.

View File

@@ -8032,20 +8032,52 @@ pub enum DeviceLostReason {
Destroyed = 1,
}
/// Descriptor for creating a shader module.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
/// Descriptor for a shader module given by any of several sources.
/// These shaders are passed through directly to the underlying api.
/// At least one shader type that may be used by the backend must be `Some` or a panic is raised.
#[derive(Debug, Clone)]
pub enum CreateShaderModuleDescriptorPassthrough<'a, L> {
/// Passthrough for SPIR-V binaries.
SpirV(ShaderModuleDescriptorSpirV<'a, L>),
/// Passthrough for MSL source code.
Msl(ShaderModuleDescriptorMsl<'a, L>),
/// Passthrough for DXIL compiled with DXC
Dxil(ShaderModuleDescriptorDxil<'a, L>),
/// Passthrough for HLSL
Hlsl(ShaderModuleDescriptorHlsl<'a, L>),
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CreateShaderModuleDescriptorPassthrough<'a, L> {
/// Entrypoint. Unused for Spir-V.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z. Unused for Spir-V.
pub num_workgroups: (u32, u32, u32),
/// Runtime checks that should be enabled.
pub runtime_checks: ShaderRuntimeChecks,
/// Binary SPIR-V data, in 4-byte words.
pub spirv: Option<Cow<'a, [u32]>>,
/// Shader DXIL source.
pub dxil: Option<Cow<'a, [u8]>>,
/// Shader MSL source.
pub msl: Option<Cow<'a, str>>,
/// Shader HLSL source.
pub hlsl: Option<Cow<'a, str>>,
/// Shader GLSL source (currently unused).
pub glsl: Option<Cow<'a, str>>,
/// Shader WGSL source.
pub wgsl: Option<Cow<'a, str>>,
}
// This is so people don't have to fill in fields they don't use, like num_workgroups,
// entry_point, or other shader languages they didn't compile for
impl<'a, L: Default> Default for CreateShaderModuleDescriptorPassthrough<'a, L> {
fn default() -> Self {
Self {
entry_point: "".into(),
label: Default::default(),
num_workgroups: (0, 0, 0),
runtime_checks: ShaderRuntimeChecks::unchecked(),
spirv: None,
dxil: None,
msl: None,
hlsl: None,
glsl: None,
wgsl: None,
}
}
}
impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
@@ -8053,134 +8085,46 @@ impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
pub fn map_label<K>(
&self,
fun: impl FnOnce(&L) -> K,
) -> CreateShaderModuleDescriptorPassthrough<'_, K> {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::SpirV(
ShaderModuleDescriptorSpirV {
label: fun(&inner.label),
source: inner.source.clone(),
},
)
}
CreateShaderModuleDescriptorPassthrough::Msl(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::Msl(ShaderModuleDescriptorMsl {
entry_point: inner.entry_point.clone(),
label: fun(&inner.label),
num_workgroups: inner.num_workgroups,
source: inner.source.clone(),
})
}
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::Dxil(ShaderModuleDescriptorDxil {
entry_point: inner.entry_point.clone(),
label: fun(&inner.label),
num_workgroups: inner.num_workgroups,
source: inner.source,
})
}
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::Hlsl(ShaderModuleDescriptorHlsl {
entry_point: inner.entry_point.clone(),
label: fun(&inner.label),
num_workgroups: inner.num_workgroups,
source: inner.source,
})
}
}
}
/// Returns the label of shader module passthrough descriptor.
pub fn label(&'a self) -> &'a L {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => &inner.label,
) -> CreateShaderModuleDescriptorPassthrough<'a, K> {
CreateShaderModuleDescriptorPassthrough {
entry_point: self.entry_point.clone(),
label: fun(&self.label),
num_workgroups: self.num_workgroups,
runtime_checks: self.runtime_checks,
spirv: self.spirv.clone(),
dxil: self.dxil.clone(),
msl: self.msl.clone(),
hlsl: self.hlsl.clone(),
glsl: self.glsl.clone(),
wgsl: self.wgsl.clone(),
}
}
#[cfg(feature = "trace")]
/// Returns the source data for tracing purpose.
pub fn trace_data(&self) -> &[u8] {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(inner) => {
bytemuck::cast_slice(&inner.source)
}
CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(),
CreateShaderModuleDescriptorPassthrough::Dxil(inner) => inner.source,
CreateShaderModuleDescriptorPassthrough::Hlsl(inner) => inner.source.as_bytes(),
if let Some(spirv) = &self.spirv {
bytemuck::cast_slice(spirv)
} else if let Some(msl) = &self.msl {
msl.as_bytes()
} else if let Some(dxil) = &self.dxil {
dxil
} else {
panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
}
}
#[cfg(feature = "trace")]
/// Returns the binary file extension for tracing purpose.
pub fn trace_binary_ext(&self) -> &'static str {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(..) => "spv",
CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl",
CreateShaderModuleDescriptorPassthrough::Dxil(..) => "dxil",
CreateShaderModuleDescriptorPassthrough::Hlsl(..) => "hlsl",
if self.spirv.is_some() {
"spv"
} else if self.msl.is_some() {
"msl"
} else if self.dxil.is_some() {
"dxil"
} else {
panic!("No binary data provided to `ShaderModuleDescriptorGeneric`")
}
}
}
/// Descriptor for a shader module given by Metal MSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorMsl<'a, L> {
/// Entrypoint.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z.
pub num_workgroups: (u32, u32, u32),
/// Shader MSL source.
pub source: Cow<'a, str>,
}
/// Descriptor for a shader module given by DirectX DXIL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorDxil<'a, L> {
/// Entrypoint.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z.
pub num_workgroups: (u32, u32, u32),
/// Shader DXIL source.
pub source: &'a [u8],
}
/// Descriptor for a shader module given by DirectX HLSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorHlsl<'a, L> {
/// Entrypoint.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z.
pub num_workgroups: (u32, u32, u32),
/// Shader HLSL source.
pub source: &'a str,
}
/// Descriptor for a shader module given by SPIR-V binary.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorSpirV<'a, L> {
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Binary SPIR-V data, in 4-byte words.
pub source: Cow<'a, [u32]>,
}

View File

@@ -228,34 +228,10 @@ pub struct ShaderModuleDescriptor<'a> {
}
static_assertions::assert_impl_all!(ShaderModuleDescriptor<'_>: Send, Sync);
/// Descriptor for a shader module that will bypass wgpu's shader tooling, for use with
/// [`Device::create_shader_module_passthrough`].
/// Descriptor for a shader module given by any of several sources.
/// At least one of the shader types that may be used by the backend must be `Some`
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorPassthrough<'a> =
wgt::CreateShaderModuleDescriptorPassthrough<'a, Label<'a>>;
/// Descriptor for a shader module given by Metal MSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorMsl<'a> = wgt::ShaderModuleDescriptorMsl<'a, Label<'a>>;
/// Descriptor for a shader module given by SPIR-V binary.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>;
/// Descriptor for a shader module given by DirectX HLSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorHlsl<'a> = wgt::ShaderModuleDescriptorHlsl<'a, Label<'a>>;
/// Descriptor for a shader module given by DirectX DXIL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorDxil<'a> = wgt::ShaderModuleDescriptorDxil<'a, Label<'a>>;

View File

@@ -1862,9 +1862,43 @@ impl dispatch::DeviceInterface for WebDevice {
unsafe fn create_shader_module_passthrough(
&self,
_desc: &crate::ShaderModuleDescriptorPassthrough<'_>,
desc: &crate::ShaderModuleDescriptorPassthrough<'_>,
) -> dispatch::DispatchShaderModule {
unreachable!("No XXX_SHADER_PASSTHROUGH feature enabled for this backend")
let shader_module_result = if let Some(ref code) = desc.wgsl {
let shader_module = webgpu_sys::GpuShaderModuleDescriptor::new(code);
Ok((
shader_module,
WebShaderCompilationInfo::Wgsl {
source: code.to_string(),
},
))
} else {
Err(crate::CompilationInfo {
messages: vec![crate::CompilationMessage {
message:
"Passthrough shader not compiled for WGSL on WebGPU backend (WGPU error)"
.to_string(),
location: None,
message_type: crate::CompilationMessageType::Error,
}],
})
};
let (descriptor, compilation_info) = match shader_module_result {
Ok(v) => v,
Err(compilation_info) => (
webgpu_sys::GpuShaderModuleDescriptor::new(""),
WebShaderCompilationInfo::Transformed { compilation_info },
),
};
if let Some(label) = desc.label {
descriptor.set_label(label);
}
WebShaderModule {
module: self.inner.create_shader_module(&descriptor),
compilation_info,
ident: crate::cmp::Identifier::create(),
}
.into()
}
fn create_bind_group_layout(

View File

@@ -1088,7 +1088,7 @@ impl dispatch::DeviceInterface for CoreDevice {
self.context.handle_error(
&self.error_sink,
cause.clone(),
desc.label().as_deref(),
desc.label.as_deref(),
"Device::create_shader_module_passthrough",
);
CompilationInfo::from(cause)

View File

@@ -96,22 +96,31 @@ macro_rules! include_spirv {
};
}
/// Macro to load raw SPIR-V data statically, for use with [`Features::SPIRV_SHADER_PASSTHROUGH`].
/// Macro to load raw SPIR-V data statically, for use with [`Features::EXPERIMENTAL_PASSTHROUGH_SHADERS`].
///
/// It ensures the word alignment as well as the magic number.
///
/// [`Features::SPIRV_SHADER_PASSTHROUGH`]: crate::Features::SPIRV_SHADER_PASSTHROUGH
/// [`Features::EXPERIMENTAL_PASSTHROUGH_SHADERS`]: crate::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS
#[macro_export]
macro_rules! include_spirv_raw {
($($token:tt)*) => {
{
//log::info!("including '{}'", $($token)*);
$crate::ShaderModuleDescriptorPassthrough::SpirV(
$crate::ShaderModuleDescriptorSpirV {
label: $crate::__macro_helpers::Some($($token)*),
source: $crate::util::make_spirv_raw($crate::__macro_helpers::include_bytes!($($token)*)),
}
)
$crate::ShaderModuleDescriptorPassthrough {
label: $crate::__macro_helpers::Some($($token)*),
spirv: Some($crate::util::make_spirv_raw($crate::__macro_helpers::include_bytes!($($token)*))),
entry_point: "".to_owned(),
// This is unused for SPIR-V
num_workgroups: (0, 0, 0),
reflection: None,
shader_runtime_checks: $crate::ShaderRuntimeChecks::unchecked(),
dxil: None,
msl: None,
hlsl: None,
glsl: None,
wgsl: None,
}
}
};
}