From 51fb9bb77afb919ac712eef566537453fe3877e1 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sun, 18 Jul 2021 00:10:19 -0400 Subject: [PATCH] hlsl-out: remap bindings --- src/back/hlsl/mod.rs | 58 +++++++++++++++++++++++++++++---- src/back/hlsl/writer.rs | 68 ++++++++++++++++++++++++++++++++------- tests/in/skybox.param.ron | 1 + tests/snapshots.rs | 41 ++++++++++++++++++----- 4 files changed, 142 insertions(+), 26 deletions(-) diff --git a/src/back/hlsl/mod.rs b/src/back/hlsl/mod.rs index dc98bf85dd..f6ab3450d6 100644 --- a/src/back/hlsl/mod.rs +++ b/src/back/hlsl/mod.rs @@ -15,9 +15,22 @@ use thiserror::Error; pub use writer::Writer; +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct BindTarget { + pub space: u8, + pub register: u8, +} + +// Using `BTreeMap` instead of `HashMap` so that we can hash itself. +pub type BindingMap = std::collections::BTreeMap; + /// A HLSL shader model version. #[allow(non_snake_case, non_camel_case_types)] #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub enum ShaderModel { V5_0, V5_1, @@ -44,28 +57,61 @@ impl crate::ShaderStage { } } +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum EntryPointError { + #[error("mapping of {0:?} is missing")] + MissingBinding(crate::ResourceBinding), +} + /// Structure that contains the configuration used in the [`Writer`](Writer) -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct Options { /// The hlsl shader model to be used pub shader_model: ShaderModel, + /// Map of resources association to binding locations. + pub binding_map: BindingMap, + /// Don't panic on missing bindings, instead generate any HLSL. + pub fake_missing_bindings: bool, } impl Default for Options { fn default() -> Self { Options { shader_model: ShaderModel::V5_0, + binding_map: BindingMap::default(), + fake_missing_bindings: true, + } + } +} + +impl Options { + fn resolve_resource_binding( + &self, + res_binding: &crate::ResourceBinding, + ) -> Result { + match self.binding_map.get(res_binding) { + Some(target) => Ok(target.clone()), + None if self.fake_missing_bindings => Ok(BindTarget { + space: res_binding.group as u8, + register: res_binding.binding as u8, + }), + None => Err(EntryPointError::MissingBinding(res_binding.clone())), } } } /// Structure that contains a reflection info pub struct ReflectionInfo { - /// Real name of entry point allowed by the `hlsl` compiler. - /// For example: - /// the entry point with the name `line` is valid for `wgsl`, but not valid for `hlsl`, because `line` is a reserved keyword. - pub entry_points: Vec, - // TODO: locations + /// Mapping of the entry point names. Each item in the array + /// corresponds to an entry point index. The real entry point name may be different if one of the + /// reserved words are used. + /// + ///Note: Some entry points may fail translation because of missing bindings. + pub entry_point_names: Vec>, } #[derive(Error, Debug)] diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 5b1697c63a..6949b5aef7 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -64,7 +64,7 @@ impl<'a, W: Write> Writer<'a, W> { pub fn write( &mut self, module: &Module, - info: &valid::ModuleInfo, + module_info: &valid::ModuleInfo, ) -> Result { self.reset(module); @@ -144,7 +144,31 @@ impl<'a, W: Write> Writer<'a, W> { // Write all regular functions for (handle, function) in module.functions.iter() { - let info = &info[handle]; + let info = &module_info[handle]; + + // Check if all of the globals are accessible + if !self.options.fake_missing_bindings { + if let Some((var_handle, _)) = + module + .global_variables + .iter() + .find(|&(var_handle, var)| match var.binding { + Some(ref binding) if !info[var_handle].is_empty() => { + self.options.resolve_resource_binding(binding).is_err() + } + _ => false, + }) + { + log::info!( + "Skipping function {:?} (name {:?}) because global {:?} is inaccessible", + handle, + function.name, + var_handle + ); + continue; + } + } + let ctx = back::FunctionCtx { ty: back::FunctionType::Function(handle), info, @@ -161,13 +185,34 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out)?; } - let mut entry_points_info = Vec::with_capacity(module.entry_points.len()); + let mut entry_point_names = Vec::with_capacity(module.entry_points.len()); // Write all entry points for (index, ep) in module.entry_points.iter().enumerate() { + let info = module_info.get_entry_point(index); + + if !self.options.fake_missing_bindings { + let mut ep_error = None; + for (var_handle, var) in module.global_variables.iter() { + match var.binding { + Some(ref binding) if !info[var_handle].is_empty() => { + if let Err(err) = self.options.resolve_resource_binding(binding) { + ep_error = Some(err); + break; + } + } + _ => {} + } + } + if let Some(err) = ep_error { + entry_point_names.push(Err(err)); + continue; + } + } + let ctx = back::FunctionCtx { ty: back::FunctionType::EntryPoint(index as u16), - info: info.get_entry_point(index), + info, expressions: &ep.function.expressions, named_expressions: &ep.function.named_expressions, }; @@ -176,7 +221,7 @@ impl<'a, W: Write> Writer<'a, W> { self.write_wrapped_image_query_functions(module, &ctx)?; if ep.stage == ShaderStage::Compute { - // HLSL is calling workgroup size, num threads + // HLSL is calling workgroup size "num threads" let num_threads = ep.workgroup_size; writeln!( self.out, @@ -186,19 +231,16 @@ impl<'a, W: Write> Writer<'a, W> { } let name = self.names[&NameKey::EntryPoint(index as u16)].clone(); - self.write_function(module, &name, &ep.function, &ctx)?; if index < module.entry_points.len() - 1 { writeln!(self.out)?; } - entry_points_info.push(name); + entry_point_names.push(Ok(name)); } - Ok(super::ReflectionInfo { - entry_points: entry_points_info, - }) + Ok(super::ReflectionInfo { entry_point_names }) } fn write_semantic( @@ -324,9 +366,11 @@ impl<'a, W: Write> Writer<'a, W> { )?; if let Some(ref binding) = global.binding { - write!(self.out, " : register({}{}", register_ty, binding.binding)?; + // this was already resolved earlier when we started evaluating an entry point. + let bt = self.options.resolve_resource_binding(binding).unwrap(); + write!(self.out, " : register({}{}", register_ty, bt.register)?; if self.options.shader_model > super::ShaderModel::V5_0 { - write!(self.out, ", space{}", binding.group)?; + write!(self.out, ", space{}", bt.space)?; } writeln!(self.out, ");")?; } else { diff --git a/tests/in/skybox.param.ron b/tests/in/skybox.param.ron index 66899e280d..03e56dde94 100644 --- a/tests/in/skybox.param.ron +++ b/tests/in/skybox.param.ron @@ -44,4 +44,5 @@ (group: 0, binding: 1): 0, }, ), + hlsl_custom: true, ) diff --git a/tests/snapshots.rs b/tests/snapshots.rs index e634ee851d..da5bd50475 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -64,6 +64,12 @@ struct Parameters { #[cfg_attr(not(feature = "glsl-out"), allow(dead_code))] #[serde(default)] glsl_comp_ep_name: Option, + #[cfg(all(feature = "deserialize", feature = "hlsl-out"))] + #[serde(default)] + hlsl: naga::back::hlsl::Options, + #[cfg(all(not(feature = "deserialize"), feature = "hlsl-out"))] + #[serde(default)] + hlsl_custom: bool, } #[allow(dead_code, unused_variables)] @@ -131,7 +137,7 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) { #[cfg(feature = "hlsl-out")] { if targets.contains(Targets::HLSL) { - write_output_hlsl(module, &info, &dest, name); + write_output_hlsl(module, &info, &dest, name, ¶ms); } } #[cfg(feature = "wgsl-out")] @@ -271,11 +277,25 @@ fn write_output_hlsl( info: &naga::valid::ModuleInfo, destination: &PathBuf, file_name: &str, + params: &Parameters, ) { use naga::back::hlsl; + use std::fmt::Write; + + #[cfg_attr(feature = "deserialize", allow(unused_variables))] + let default_options = hlsl::Options::default(); + #[cfg(feature = "deserialize")] + let options = ¶ms.hlsl; + #[cfg(not(feature = "deserialize"))] + let options = if params.hlsl_custom { + println!("Skipping {}", destination.display()); + return; + } else { + &default_options + }; + let mut buffer = String::new(); - let options = hlsl::Options::default(); - let mut writer = hlsl::Writer::new(&mut buffer, &options); + let mut writer = hlsl::Writer::new(&mut buffer, options); let reflection_info = writer.write(module, info).unwrap(); fs::write(destination.join(format!("hlsl/{}.hlsl", file_name)), buffer).unwrap(); @@ -283,22 +303,27 @@ fn write_output_hlsl( // We need a config file for validation script // This file contains an info about profiles (shader stages) contains inside generated shader // This info will be passed to dxc - let mut config_str = String::from(""); + let mut config_str = String::new(); for (index, ep) in module.entry_points.iter().enumerate() { + let name = match reflection_info.entry_point_names[index] { + Ok(ref name) => name, + Err(_) => continue, + }; let stage_str = match ep.stage { naga::ShaderStage::Vertex => "vertex", naga::ShaderStage::Fragment => "fragment", naga::ShaderStage::Compute => "compute", }; - config_str = format!( - "{}{}={}_{}\n{}_name={}\n", + writeln!( config_str, + "{}={}_{}\n{}_name={}", stage_str, ep.stage.to_hlsl_str(), options.shader_model.to_str(), stage_str, - &reflection_info.entry_points[index] - ); + name, + ) + .unwrap(); } fs::write( destination.join(format!("hlsl/{}.hlsl.config", file_name)),