hlsl-out: remap bindings

This commit is contained in:
Dzmitry Malyshau
2021-07-18 00:10:19 -04:00
committed by Dzmitry Malyshau
parent 2a253ab838
commit 51fb9bb77a
4 changed files with 142 additions and 26 deletions

View File

@@ -15,9 +15,22 @@ use thiserror::Error;
pub use writer::Writer;
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct BindTarget {
pub space: u8,
pub register: u8,
}
// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
/// A HLSL shader model version.
#[allow(non_snake_case, non_camel_case_types)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum ShaderModel {
V5_0,
V5_1,
@@ -44,28 +57,61 @@ impl crate::ShaderStage {
}
}
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum EntryPointError {
#[error("mapping of {0:?} is missing")]
MissingBinding(crate::ResourceBinding),
}
/// Structure that contains the configuration used in the [`Writer`](Writer)
#[derive(Debug, Clone)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Options {
/// The hlsl shader model to be used
pub shader_model: ShaderModel,
/// Map of resources association to binding locations.
pub binding_map: BindingMap,
/// Don't panic on missing bindings, instead generate any HLSL.
pub fake_missing_bindings: bool,
}
impl Default for Options {
fn default() -> Self {
Options {
shader_model: ShaderModel::V5_0,
binding_map: BindingMap::default(),
fake_missing_bindings: true,
}
}
}
impl Options {
fn resolve_resource_binding(
&self,
res_binding: &crate::ResourceBinding,
) -> Result<BindTarget, EntryPointError> {
match self.binding_map.get(res_binding) {
Some(target) => Ok(target.clone()),
None if self.fake_missing_bindings => Ok(BindTarget {
space: res_binding.group as u8,
register: res_binding.binding as u8,
}),
None => Err(EntryPointError::MissingBinding(res_binding.clone())),
}
}
}
/// Structure that contains a reflection info
pub struct ReflectionInfo {
/// Real name of entry point allowed by the `hlsl` compiler.
/// For example:
/// the entry point with the name `line` is valid for `wgsl`, but not valid for `hlsl`, because `line` is a reserved keyword.
pub entry_points: Vec<String>,
// TODO: locations
/// Mapping of the entry point names. Each item in the array
/// corresponds to an entry point index. The real entry point name may be different if one of the
/// reserved words are used.
///
///Note: Some entry points may fail translation because of missing bindings.
pub entry_point_names: Vec<Result<String, EntryPointError>>,
}
#[derive(Error, Debug)]

View File

@@ -64,7 +64,7 @@ impl<'a, W: Write> Writer<'a, W> {
pub fn write(
&mut self,
module: &Module,
info: &valid::ModuleInfo,
module_info: &valid::ModuleInfo,
) -> Result<super::ReflectionInfo, Error> {
self.reset(module);
@@ -144,7 +144,31 @@ impl<'a, W: Write> Writer<'a, W> {
// Write all regular functions
for (handle, function) in module.functions.iter() {
let info = &info[handle];
let info = &module_info[handle];
// Check if all of the globals are accessible
if !self.options.fake_missing_bindings {
if let Some((var_handle, _)) =
module
.global_variables
.iter()
.find(|&(var_handle, var)| match var.binding {
Some(ref binding) if !info[var_handle].is_empty() => {
self.options.resolve_resource_binding(binding).is_err()
}
_ => false,
})
{
log::info!(
"Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
handle,
function.name,
var_handle
);
continue;
}
}
let ctx = back::FunctionCtx {
ty: back::FunctionType::Function(handle),
info,
@@ -161,13 +185,34 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out)?;
}
let mut entry_points_info = Vec::with_capacity(module.entry_points.len());
let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
// Write all entry points
for (index, ep) in module.entry_points.iter().enumerate() {
let info = module_info.get_entry_point(index);
if !self.options.fake_missing_bindings {
let mut ep_error = None;
for (var_handle, var) in module.global_variables.iter() {
match var.binding {
Some(ref binding) if !info[var_handle].is_empty() => {
if let Err(err) = self.options.resolve_resource_binding(binding) {
ep_error = Some(err);
break;
}
}
_ => {}
}
}
if let Some(err) = ep_error {
entry_point_names.push(Err(err));
continue;
}
}
let ctx = back::FunctionCtx {
ty: back::FunctionType::EntryPoint(index as u16),
info: info.get_entry_point(index),
info,
expressions: &ep.function.expressions,
named_expressions: &ep.function.named_expressions,
};
@@ -176,7 +221,7 @@ impl<'a, W: Write> Writer<'a, W> {
self.write_wrapped_image_query_functions(module, &ctx)?;
if ep.stage == ShaderStage::Compute {
// HLSL is calling workgroup size, num threads
// HLSL is calling workgroup size "num threads"
let num_threads = ep.workgroup_size;
writeln!(
self.out,
@@ -186,19 +231,16 @@ impl<'a, W: Write> Writer<'a, W> {
}
let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
self.write_function(module, &name, &ep.function, &ctx)?;
if index < module.entry_points.len() - 1 {
writeln!(self.out)?;
}
entry_points_info.push(name);
entry_point_names.push(Ok(name));
}
Ok(super::ReflectionInfo {
entry_points: entry_points_info,
})
Ok(super::ReflectionInfo { entry_point_names })
}
fn write_semantic(
@@ -324,9 +366,11 @@ impl<'a, W: Write> Writer<'a, W> {
)?;
if let Some(ref binding) = global.binding {
write!(self.out, " : register({}{}", register_ty, binding.binding)?;
// this was already resolved earlier when we started evaluating an entry point.
let bt = self.options.resolve_resource_binding(binding).unwrap();
write!(self.out, " : register({}{}", register_ty, bt.register)?;
if self.options.shader_model > super::ShaderModel::V5_0 {
write!(self.out, ", space{}", binding.group)?;
write!(self.out, ", space{}", bt.space)?;
}
writeln!(self.out, ");")?;
} else {

View File

@@ -44,4 +44,5 @@
(group: 0, binding: 1): 0,
},
),
hlsl_custom: true,
)

View File

@@ -64,6 +64,12 @@ struct Parameters {
#[cfg_attr(not(feature = "glsl-out"), allow(dead_code))]
#[serde(default)]
glsl_comp_ep_name: Option<String>,
#[cfg(all(feature = "deserialize", feature = "hlsl-out"))]
#[serde(default)]
hlsl: naga::back::hlsl::Options,
#[cfg(all(not(feature = "deserialize"), feature = "hlsl-out"))]
#[serde(default)]
hlsl_custom: bool,
}
#[allow(dead_code, unused_variables)]
@@ -131,7 +137,7 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) {
#[cfg(feature = "hlsl-out")]
{
if targets.contains(Targets::HLSL) {
write_output_hlsl(module, &info, &dest, name);
write_output_hlsl(module, &info, &dest, name, &params);
}
}
#[cfg(feature = "wgsl-out")]
@@ -271,11 +277,25 @@ fn write_output_hlsl(
info: &naga::valid::ModuleInfo,
destination: &PathBuf,
file_name: &str,
params: &Parameters,
) {
use naga::back::hlsl;
use std::fmt::Write;
#[cfg_attr(feature = "deserialize", allow(unused_variables))]
let default_options = hlsl::Options::default();
#[cfg(feature = "deserialize")]
let options = &params.hlsl;
#[cfg(not(feature = "deserialize"))]
let options = if params.hlsl_custom {
println!("Skipping {}", destination.display());
return;
} else {
&default_options
};
let mut buffer = String::new();
let options = hlsl::Options::default();
let mut writer = hlsl::Writer::new(&mut buffer, &options);
let mut writer = hlsl::Writer::new(&mut buffer, options);
let reflection_info = writer.write(module, info).unwrap();
fs::write(destination.join(format!("hlsl/{}.hlsl", file_name)), buffer).unwrap();
@@ -283,22 +303,27 @@ fn write_output_hlsl(
// We need a config file for validation script
// This file contains an info about profiles (shader stages) contains inside generated shader
// This info will be passed to dxc
let mut config_str = String::from("");
let mut config_str = String::new();
for (index, ep) in module.entry_points.iter().enumerate() {
let name = match reflection_info.entry_point_names[index] {
Ok(ref name) => name,
Err(_) => continue,
};
let stage_str = match ep.stage {
naga::ShaderStage::Vertex => "vertex",
naga::ShaderStage::Fragment => "fragment",
naga::ShaderStage::Compute => "compute",
};
config_str = format!(
"{}{}={}_{}\n{}_name={}\n",
writeln!(
config_str,
"{}={}_{}\n{}_name={}",
stage_str,
ep.stage.to_hlsl_str(),
options.shader_model.to_str(),
stage_str,
&reflection_info.entry_points[index]
);
name,
)
.unwrap();
}
fs::write(
destination.join(format!("hlsl/{}.hlsl.config", file_name)),