diff --git a/Makefile b/Makefile index 8452f2ede4..30da29c4cf 100644 --- a/Makefile +++ b/Makefile @@ -65,7 +65,15 @@ validate-wgsl: $(SNAPSHOTS_OUT)/*.wgsl done validate-hlsl: $(SNAPSHOTS_OUT)/*.hlsl - @set -e && for file in $(SNAPSHOTS_OUT)/*.Compute.hlsl ; do \ - echo "Validating" $${file#"$(SNAPSHOTS_OUT)/"};\ - dxc $${file} -T cs_5_0;\ + @set -e && for file in $^ ; do \ + echo "Validating" $${file#"$(SNAPSHOTS_OUT)/"}; \ + config="$$(dirname $${file})/$$(basename $${file}).config"; \ + vertex=""\ + fragment="" \ + compute="" \ + . $${config}; \ + [ ! -z "$${vertex}" ] && echo "Vertex Stage:" && dxc $${file} -T $${vertex} -E $${vertex_name} -Wno-parentheses-equality -Zi -Qembed_debug > /dev/null; \ + [ ! -z "$${fragment}" ] && echo "Fragment Stage:" && dxc $${file} -T $${fragment} -E $${fragment_name} -Wno-parentheses-equality -Zi -Qembed_debug > /dev/null; \ + [ ! -z "$${compute}" ] && echo "Compute Stage:" && dxc $${file} -T $${compute} -E $${compute_name} -Wno-parentheses-equality -Zi -Qembed_debug > /dev/null; \ + echo "======================"; \ done diff --git a/cli/src/main.rs b/cli/src/main.rs index 3935950208..14d4c11853 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -11,6 +11,7 @@ struct Parameters { spv: naga::back::spv::Options, msl: naga::back::msl::Options, glsl: naga::back::glsl::Options, + hlsl: naga::back::hlsl::Options, } trait PrettyResult { @@ -75,6 +76,12 @@ fn main() { panic!("Unknown profile: {}", string) }; } + "shader-model" => { + use naga::back::hlsl::{ShaderModel, DEFAULT_SHADER_MODEL}; + let string = args.next().unwrap(); + params.hlsl.shader_model = + ShaderModel::new(string.parse().unwrap_or(DEFAULT_SHADER_MODEL)); + } other => log::warn!("Unknown parameter: {}", other), } } else if input_path.is_none() { @@ -251,13 +258,8 @@ fn main() { } "hlsl" => { use naga::back::hlsl; - // TODO: Get `ShaderModel` from user - let hlsl = hlsl::write_string( - &module, - info.as_ref().unwrap(), - hlsl::ShaderModel::default(), - ) - .unwrap_pretty(); + let hlsl = hlsl::write_string(&module, info.as_ref().unwrap(), ¶ms.hlsl) + .unwrap_pretty(); fs::write(output_path, hlsl).unwrap(); } "wgsl" => { diff --git a/src/back/hlsl/mod.rs b/src/back/hlsl/mod.rs index 5f049a4a99..3b94239a05 100644 --- a/src/back/hlsl/mod.rs +++ b/src/back/hlsl/mod.rs @@ -6,12 +6,43 @@ use thiserror::Error; pub use writer::Writer; +pub const DEFAULT_SHADER_MODEL: u16 = 50; #[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] pub struct ShaderModel(u16); +impl ShaderModel { + pub fn new(shader_model: u16) -> Self { + Self(shader_model) + } +} + impl Default for ShaderModel { fn default() -> Self { - ShaderModel(50) + Self(DEFAULT_SHADER_MODEL) + } +} + +/// Structure that contains the configuration used in the [`Writer`](Writer) +#[derive(Debug, Clone)] +pub struct Options { + /// The hlsl shader model to be used + pub shader_model: ShaderModel, + /// The vertex entry point name in generated shader + pub vertex_entry_point_name: String, + /// The fragment entry point name in generated shader + pub fragment_entry_point_name: String, + /// The comput entry point name in generated shader + pub compute_entry_point_name: String, +} + +impl Default for Options { + fn default() -> Self { + Options { + shader_model: ShaderModel(50), + vertex_entry_point_name: String::from("vert_main"), + fragment_entry_point_name: String::from("frag_main"), + compute_entry_point_name: String::from("comp_main"), + } } } @@ -23,18 +54,18 @@ pub enum Error { UnsupportedShaderModel(ShaderModel), #[error("A scalar with an unsupported width was requested: {0:?} {1:?}")] UnsupportedScalar(crate::ScalarKind, crate::Bytes), - #[error("BuiltIn {0:?} is not supported")] - UnsupportedBuiltIn(crate::BuiltIn), #[error("{0}")] Unimplemented(String), // TODO: Error used only during development + #[error("{0}")] + Custom(String), } pub fn write_string( module: &crate::Module, info: &crate::valid::ModuleInfo, - shader_model: ShaderModel, + options: &Options, ) -> Result { - let mut w = Writer::new(String::new(), shader_model); + let mut w = Writer::new(String::new(), options); w.write(module, info)?; let output = w.finish(); Ok(output) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 78e0861372..c2260cf187 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1,16 +1,19 @@ //TODO: temp #![allow(dead_code)] -use super::{Error, ShaderModel}; -use crate::back::hlsl::keywords::RESERVED; -use crate::proc::{EntryPointIndex, NameKey, Namer}; -use crate::valid::{FunctionInfo, ModuleInfo}; +use super::{Error, Options, ShaderModel}; use crate::{ - Arena, Bytes, Constant, Expression, FastHashMap, Function, Handle, ImageDimension, - LocalVariable, Module, ScalarKind, ShaderStage, Statement, Type, TypeInner, + back::{hlsl::keywords::RESERVED, vector_size_str}, + proc::{EntryPointIndex, NameKey, Namer, TypeResolution}, + valid::{FunctionInfo, ModuleInfo}, + Arena, BuiltIn, Bytes, Constant, ConstantInner, Expression, FastHashMap, Function, + GlobalVariable, Handle, ImageDimension, LocalVariable, Module, ScalarKind, ScalarValue, + ShaderStage, Statement, StructMember, Type, TypeInner, }; use std::fmt::Write; const INDENT: &str = " "; +const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; +const LOCATION_SEMANTIC: &str = "LOC"; /// Shorthand result used internally by the backend type BackendResult = Result<(), Error>; @@ -49,21 +52,35 @@ impl<'a> FunctionCtx<'_> { } } -pub struct Writer { +struct EntryPointBinding { + stage: ShaderStage, + name: String, + members: Vec, +} + +struct EpStructMember { + pub name: String, + pub ty: Handle, + pub binding: Option, +} + +pub struct Writer<'a, W> { out: W, names: FastHashMap, namer: Namer, - shader_model: ShaderModel, + options: &'a Options, + ep_inputs: Vec>, named_expressions: crate::NamedExpressions, } -impl Writer { - pub fn new(out: W, shader_model: ShaderModel) -> Self { - Writer { +impl<'a, W: Write> Writer<'a, W> { + pub fn new(out: W, options: &'a Options) -> Self { + Self { out, names: FastHashMap::default(), namer: Namer::default(), - shader_model, + options, + ep_inputs: Vec::with_capacity(3), named_expressions: crate::NamedExpressions::default(), } } @@ -72,23 +89,76 @@ impl Writer { self.names.clear(); self.namer.reset(module, RESERVED, &[], &mut self.names); self.named_expressions.clear(); + self.ep_inputs.clear(); } pub fn write(&mut self, module: &Module, info: &ModuleInfo) -> BackendResult { - if self.shader_model < ShaderModel(50) { - return Err(Error::UnsupportedShaderModel(self.shader_model)); + if self.options.shader_model < ShaderModel::default() { + return Err(Error::UnsupportedShaderModel(self.options.shader_model)); } self.reset(module); + // Write all constants + for (handle, constant) in module.constants.iter() { + if constant.name.is_some() { + self.write_global_constant(module, &constant.inner, handle)?; + } + } + + // Write all globals + for (ty, _) in module.global_variables.iter() { + self.write_global(module, ty)?; + } + + if !module.global_variables.is_empty() { + // Add extra newline for readability + writeln!(self.out)?; + } + + // Write all structs + for (handle, ty) in module.types.iter() { + if let TypeInner::Struct { + top_level, + ref members, + .. + } = ty.inner + { + self.write_struct(module, handle, top_level, members)?; + writeln!(self.out)?; + } + } + + // Write all entry points wrapped structs + for (index, ep) in module.entry_points.iter().enumerate() { + self.write_ep_input_struct(module, &ep.function, ep.stage, index)?; + } + + // Write all regular functions + for (handle, function) in module.functions.iter() { + let info = &info[handle]; + let ctx = FunctionCtx { + ty: FunctionType::Function(handle), + info, + expressions: &function.expressions, + named_expressions: &function.named_expressions, + }; + let name = self.names[&NameKey::Function(handle)].clone(); + + self.write_function(module, name.as_str(), function, &ctx)?; + + writeln!(self.out)?; + } + // Write all entry points for (index, ep) in module.entry_points.iter().enumerate() { - let func_ctx = FunctionCtx { + let ctx = FunctionCtx { ty: FunctionType::EntryPoint(index as u16), info: info.get_entry_point(index), expressions: &ep.function.expressions, named_expressions: &ep.function.named_expressions, }; + if ep.stage == ShaderStage::Compute { // HLSL is calling workgroup size, num threads let num_threads = ep.workgroup_size; @@ -98,7 +168,14 @@ impl Writer { num_threads[0], num_threads[1], num_threads[2] )?; } - self.write_function(module, &ep.function, &func_ctx)?; + + let name = match ep.stage { + ShaderStage::Vertex => &self.options.vertex_entry_point_name, + ShaderStage::Fragment => &self.options.fragment_entry_point_name, + ShaderStage::Compute => &self.options.compute_entry_point_name, + }; + + self.write_function(module, name, &ep.function, &ctx)?; if index < module.entry_points.len() - 1 { writeln!(self.out)?; @@ -108,6 +185,183 @@ impl Writer { Ok(()) } + fn write_binding(&mut self, binding: &crate::Binding) -> BackendResult { + match *binding { + crate::Binding::BuiltIn(builtin) => { + write!(self.out, " : {}", builtin_str(builtin))?; + } + crate::Binding::Location { location, .. } => { + write!(self.out, " : {}{}", LOCATION_SEMANTIC, location)?; + } + } + + Ok(()) + } + + fn write_ep_input_struct( + &mut self, + module: &Module, + func: &Function, + stage: ShaderStage, + index: usize, + ) -> BackendResult { + if !func.arguments.is_empty() { + let struct_name = self.namer.call_unique(match stage { + ShaderStage::Vertex => "VertexInput", + ShaderStage::Fragment => "FragmentInput", + ShaderStage::Compute => "ComputeInput", + }); + + let mut members = Vec::with_capacity(func.arguments.len()); + + write!(self.out, "struct {}", &struct_name)?; + writeln!(self.out, " {{")?; + + for arg in func.arguments.iter() { + let member_name = if let Some(ref name) = arg.name { + name + } else { + "member" + }; + let member = EpStructMember { + name: self.namer.call_unique(member_name), + ty: arg.ty, + binding: arg.binding.clone(), + }; + + write!(self.out, "{}", INDENT)?; + self.write_type(module, member.ty)?; + write!(self.out, " {}", &member.name)?; + if let Some(ref binding) = member.binding { + self.write_binding(binding)?; + } + write!(self.out, ";")?; + writeln!(self.out)?; + + members.push(member); + } + + writeln!(self.out, "}};")?; + writeln!(self.out)?; + + let ep_input = EntryPointBinding { + stage, + name: struct_name, + members, + }; + + self.ep_inputs.insert(index, Some(ep_input)); + } + + Ok(()) + } + + /// Helper method used to write global variables + /// # Notes + /// Always adds a newline + fn write_global(&mut self, module: &Module, handle: Handle) -> BackendResult { + let global = &module.global_variables[handle]; + let inner = &module.types[global.ty].inner; + + let register_ty = match *inner { + TypeInner::Image { .. } => "t", + TypeInner::Sampler { .. } => "s", + // TODO: other register ty https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register + _ => return Err(Error::Unimplemented(format!("register_ty {:?}", inner))), + }; + + let register = if let Some(ref binding) = global.binding { + format!("register({}{})", register_ty, binding.binding) + } else { + String::from("") + }; + + let name = self.names[&NameKey::GlobalVariable(handle)].clone(); + self.write_type(module, global.ty)?; + write!(self.out, " {}", name)?; + + if register.is_empty() { + writeln!(self.out, ";")?; + } else { + writeln!(self.out, " : {};", register)?; + } + + Ok(()) + } + + /// Helper method used to write global constants + /// + /// # Notes + /// Ends in a newline + fn write_global_constant( + &mut self, + _module: &Module, + inner: &ConstantInner, + handle: Handle, + ) -> BackendResult { + match *inner { + ConstantInner::Scalar { + width: _, + ref value, + } => { + let name = &self.names[&NameKey::Constant(handle)]; + let (ty, value) = match *value { + crate::ScalarValue::Sint(value) => ("int", format!("{}", value)), + crate::ScalarValue::Uint(value) => ("uint", format!("{}", value)), + crate::ScalarValue::Float(value) => { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + ("float", format!("{:?}", value)) + } + crate::ScalarValue::Bool(value) => ("bool", format!("{}", value)), + }; + writeln!(self.out, "static const {} {} = {};", ty, name, value)?; + } + ConstantInner::Composite { .. } => { + return Err(Error::Unimplemented(format!( + "write_global_constant Composite {:?}", + inner + ))) + } + } + // End with extra newline for readability + writeln!(self.out)?; + Ok(()) + } + + /// Helper method used to write structs + /// + /// # Notes + /// Ends in a newline + fn write_struct( + &mut self, + module: &Module, + handle: Handle, + _block: bool, + members: &[StructMember], + ) -> BackendResult { + // Write struct name + write!(self.out, "struct {}", self.names[&NameKey::Type(handle)])?; + writeln!(self.out, " {{")?; + + for (index, member) in members.iter().enumerate() { + // The indentation is only for readability + write!(self.out, "{}", INDENT)?; + // Write struct member type and name + self.write_type(module, member.ty)?; + let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; + write!(self.out, " {}", member_name)?; + if let Some(ref binding) = member.binding { + self.write_binding(binding)?; + }; + write!(self.out, ";")?; + writeln!(self.out)?; + } + + writeln!(self.out, "}};")?; + Ok(()) + } + /// Helper method used to write non image/sampler types /// /// # Notes @@ -131,6 +385,52 @@ impl Writer { TypeInner::Scalar { kind, width } => { write!(self.out, "{}", scalar_kind_str(kind, width)?)?; } + TypeInner::Vector { size, kind, width } => { + write!( + self.out, + "{}{}", + scalar_kind_str(kind, width)?, + vector_size_str(size) + )?; + } + TypeInner::Matrix { + columns, + rows, + width, + } => { + //TODO: int matrix ? + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix + write!( + self.out, + "{}{}x{}", + scalar_kind_str(ScalarKind::Float, width)?, + vector_size_str(columns), + vector_size_str(rows), + )?; + } + TypeInner::Image { + dim, + arrayed: _, //TODO: + class, + } => { + let dim_str = image_dimension_str(dim); + if let crate::ImageClass::Sampled { kind, multi: false } = class { + write!( + self.out, + "Texture{}<{}4>", + dim_str, + scalar_kind_str(kind, 4)? + )? + } else { + return Err(Error::Unimplemented(format!( + "write_value_type {:?}", + inner + ))); + } + } + TypeInner::Sampler { comparison: false } => { + write!(self.out, "SamplerState")?; + } _ => { return Err(Error::Unimplemented(format!( "write_value_type {:?}", @@ -148,6 +448,7 @@ impl Writer { fn write_function( &mut self, module: &Module, + name: &str, func: &Function, func_ctx: &FunctionCtx<'_>, ) -> BackendResult { @@ -158,51 +459,60 @@ impl Writer { write!(self.out, "void")?; } - let func_name = match func_ctx.ty { - FunctionType::EntryPoint(index) => self.names[&NameKey::EntryPoint(index)].clone(), - FunctionType::Function(handle) => self.names[&NameKey::Function(handle)].clone(), - }; - // Write function name - write!(self.out, " {}(", func_name)?; + write!(self.out, " {}(", name)?; - // Write function arguments - for (index, arg) in func.arguments.iter().enumerate() { - // Write argument type - self.write_type(module, arg.ty)?; + // Write function arguments for non entry point functions + match func_ctx.ty { + FunctionType::Function(handle) => { + for (index, arg) in func.arguments.iter().enumerate() { + // Write argument type + self.write_type(module, arg.ty)?; - let argument_name = match func_ctx.ty { - FunctionType::Function(handle) => { - self.names[&NameKey::FunctionArgument(handle, index as u32)].clone() + let argument_name = + &self.names[&NameKey::FunctionArgument(handle, index as u32)]; + + // Write argument name. Space is important. + write!(self.out, " {}", argument_name)?; + if index < func.arguments.len() - 1 { + // Add a separator between args + write!(self.out, ", ")?; + } } - FunctionType::EntryPoint(ep_index) => { - self.names[&NameKey::EntryPointArgument(ep_index, index as u32)].clone() + } + FunctionType::EntryPoint(index) => { + // EntryPoint arguments wrapped into structure + if !self.ep_inputs.is_empty() { + if let Some(ref ep_input) = self.ep_inputs[index as usize] { + write!( + self.out, + "{} {}", + ep_input.name, + self.namer + .call_unique(ep_input.name.to_lowercase().as_str()) + )?; + } } - }; - - // Write argument name. Space is important. - write!(self.out, " {}", argument_name)?; - if index < func.arguments.len() - 1 { - // Add a separator between args - write!(self.out, ", ")?; } } // Ends of arguments write!(self.out, ")")?; // Write semantic if it present + let stage = match func_ctx.ty { + FunctionType::EntryPoint(index) => Some(module.entry_points[index as usize].stage), + _ => None, + }; if let Some(ref result) = func.result { if let Some(ref binding) = result.binding { match *binding { crate::Binding::BuiltIn(builtin) => { - write!(self.out, " : {}", builtin_str(builtin)?)? + write!(self.out, " : {}", builtin_str(builtin))?; } - // TODO: Is this reachable ? - crate::Binding::Location { .. } => { - return Err(Error::Unimplemented(format!( - "write_function semantic {:?}", - binding - ))) + crate::Binding::Location { location, .. } => { + if stage == Some(ShaderStage::Fragment) { + write!(self.out, " : SV_Target{}", location)?; + } } } } @@ -267,14 +577,105 @@ impl Writer { indent: usize, ) -> BackendResult { match *stmt { - Statement::Return { value } => { - write!(self.out, "{}return", INDENT.repeat(indent))?; - if let Some(return_value) = value { - // The leading space is important - write!(self.out, " ")?; - self.write_expr(module, return_value, func_ctx)?; + Statement::Emit(ref range) => { + for handle in range.clone() { + let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) { + // Front end provides names for all variables at the start of writing. + // But we write them to step by step. We need to recache them + // Otherwise, we could accidentally write variable name instead of full expression. + // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords. + Some(self.namer.call_unique(name)) + } else { + let min_ref_count = func_ctx.expressions[handle].bake_ref_count(); + if min_ref_count <= func_ctx.info[handle].ref_count { + Some(format!("_expr{}", handle.index())) + } else { + None + } + }; + + if let Some(name) = expr_name { + write!(self.out, "{}", INDENT.repeat(indent))?; + self.write_named_expr(module, handle, name, func_ctx)?; + } + } + } + // TODO: copy-paste from glsl-out + Statement::Block(ref block) => { + write!(self.out, "{}", INDENT.repeat(indent))?; + writeln!(self.out, "{{")?; + for sta in block.iter() { + // Increase the indentation to help with readability + self.write_stmt(module, sta, func_ctx, indent + 1)? + } + writeln!(self.out, "{}}}", INDENT.repeat(indent))? + } + // TODO: copy-paste from glsl-out + Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{}", INDENT.repeat(indent))?; + write!(self.out, "if (")?; + self.write_expr(module, condition, func_ctx)?; + writeln!(self.out, ") {{")?; + + for sta in accept { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, indent + 1)?; + } + + // If there are no statements in the reject block we skip writing it + // This is only for readability + if !reject.is_empty() { + writeln!(self.out, "{}}} else {{", INDENT.repeat(indent))?; + + for sta in reject { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, indent + 1)?; + } + } + + writeln!(self.out, "{}}}", INDENT.repeat(indent))? + } + // TODO: copy-paste from glsl-out + Statement::Kill => writeln!(self.out, "{}discard;", INDENT.repeat(indent))?, + Statement::Return { value: None } => { + writeln!(self.out, "{}return;", INDENT.repeat(indent))?; + } + Statement::Return { value: Some(expr) } => { + let base_ty_res = &func_ctx.info[expr].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + if let TypeInner::Pointer { base, class: _ } = *resolved { + resolved = &module.types[base].inner; + } + + if let TypeInner::Struct { .. } = *resolved { + // We can safery unwrap here, since we now we working with struct + let ty = base_ty_res.handle().unwrap(); + let struct_name = &self.names[&NameKey::Type(ty)]; + let variable_name = self.namer.call_unique(struct_name.as_str()).to_lowercase(); + write!( + self.out, + "{}const {} {} = ", + INDENT.repeat(indent), + struct_name, + variable_name + )?; + self.write_expr(module, expr, func_ctx)?; + writeln!(self.out)?; + writeln!( + self.out, + "{}return {};", + INDENT.repeat(indent), + variable_name + )?; + } else { + write!(self.out, "{}return ", INDENT.repeat(indent))?; + self.write_expr(module, expr, func_ctx)?; + writeln!(self.out, ";")? } - writeln!(self.out, ";")?; } _ => return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt))), } @@ -288,7 +689,7 @@ impl Writer { /// Doesn't add any newlines or leading/trailing spaces fn write_expr( &mut self, - _module: &Module, + module: &Module, expr: Handle, func_ctx: &FunctionCtx<'_>, ) -> BackendResult { @@ -299,10 +700,119 @@ impl Writer { let expression = &func_ctx.expressions[expr]; - #[allow(clippy::match_single_binding)] match *expression { + Expression::Constant(constant) => self.write_constant(module, constant)?, + Expression::Compose { ty, ref components } => { + let is_struct = if let TypeInner::Struct { .. } = module.types[ty].inner { + true + } else { + false + }; + if is_struct { + write!(self.out, "{{ ")?; + } else { + self.write_type(module, ty)?; + write!(self.out, "(")?; + } + for (index, component) in components.iter().enumerate() { + self.write_expr(module, *component, func_ctx)?; + // Only write a comma if isn't the last element + if index != components.len().saturating_sub(1) { + // The leading space is for readability only + write!(self.out, ", ")?; + } + } + if is_struct { + write!(self.out, " }};")? + } else { + write!(self.out, ")")? + } + } + // TODO: copy-paste from wgsl-out + Expression::Binary { op, left, right } => { + write!(self.out, "(")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, " {} ", crate::back::binary_operation_str(op))?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ")")?; + } + // TODO: copy-paste from glsl-out + Expression::AccessIndex { base, index } => { + self.write_expr(module, base, func_ctx)?; + + let base_ty_res = &func_ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + let base_ty_handle = match *resolved { + TypeInner::Pointer { base, class: _ } => { + resolved = &module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + + match *resolved { + TypeInner::Vector { .. } => { + // Write vector access as a swizzle + write!(self.out, ".{}", COMPONENTS[index as usize])? + } + TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::ValuePointer { .. } => write!(self.out, "[{}]", index)?, + TypeInner::Struct { .. } => { + // This will never panic in case the type is a `Struct`, this is not true + // for other types so we can only check while inside this match arm + let ty = base_ty_handle.unwrap(); + + write!( + self.out, + ".{}", + &self.names[&NameKey::StructMember(ty, index)] + )? + } + ref other => return Err(Error::Custom(format!("Cannot index {:?}", other))), + } + } + Expression::FunctionArgument(pos) => { + let name = match func_ctx.ty { + FunctionType::Function(handle) => { + self.names[&NameKey::FunctionArgument(handle, pos)].clone() + } + FunctionType::EntryPoint(index) => { + // EntryPoint arguments wrapped into structure + // We can safery unwrap here, because if we write function arguments it means, that ep_input struct already exists + let ep_input = self.ep_inputs[index as usize].as_ref().unwrap(); + let member_name = &ep_input.members[pos as usize].name; + format!("{}.{}", &ep_input.name.to_lowercase(), member_name) + } + }; + write!(self.out, "{}", name)?; + } + Expression::ImageSample { + image, + sampler, // TODO: + coordinate, // TODO: + array_index: _, // TODO: + offset: _, // TODO: + level: _, // TODO: + depth_ref: _, // TODO: + } => { + // TODO: others + self.write_expr(module, image, func_ctx)?; + write!(self.out, ".Sample(")?; + self.write_expr(module, sampler, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + write!(self.out, ")")?; + } + // TODO: copy-paste from wgsl-out + Expression::GlobalVariable(handle) => { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{}", name)?; + } _ => return Err(Error::Unimplemented(format!("write_expr {:?}", expression))), } + + Ok(()) } /// Helper method used to write constants @@ -311,8 +821,13 @@ impl Writer { /// Doesn't add any newlines or leading/trailing spaces fn write_constant(&mut self, module: &Module, handle: Handle) -> BackendResult { let constant = &module.constants[handle]; - #[allow(clippy::match_single_binding)] match constant.inner { + crate::ConstantInner::Scalar { + width: _, + ref value, + } => { + self.write_scalar_value(*value)?; + } _ => { return Err(Error::Unimplemented(format!( "write_constant {:?}", @@ -320,6 +835,65 @@ impl Writer { ))) } } + + Ok(()) + } + + /// Helper method used to write [`ScalarValue`](ScalarValue) + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_scalar_value(&mut self, value: ScalarValue) -> BackendResult { + match value { + ScalarValue::Sint(value) => write!(self.out, "{}", value)?, + ScalarValue::Uint(value) => write!(self.out, "{}u", value)?, + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + ScalarValue::Float(value) => write!(self.out, "{:?}", value)?, + ScalarValue::Bool(value) => write!(self.out, "{}", value)?, + } + + Ok(()) + } + + fn write_named_expr( + &mut self, + module: &Module, + handle: Handle, + name: String, + ctx: &FunctionCtx, + ) -> BackendResult { + match ctx.info[handle].ty { + TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner { + TypeInner::Struct { .. } => { + let ty_name = &self.names[&NameKey::Type(ty_handle)]; + write!(self.out, "{}", ty_name)?; + } + _ => { + self.write_type(module, ty_handle)?; + } + }, + TypeResolution::Value(ref inner) => { + self.write_value_type(module, inner)?; + } + } + + let base_ty_res = &ctx.info[handle].ty; + let resolved = base_ty_res.inner_with(&module.types); + + // If rhs is a array type, we should write temp variable as a dynamic array + let array_str = if let TypeInner::Array { .. } = *resolved { + "[]" + } else { + "" + }; + + write!(self.out, " {}{} = ", name, array_str)?; + self.write_expr(module, handle, ctx)?; + writeln!(self.out, ";")?; + self.named_expressions.insert(handle, name); + + Ok(()) } pub fn finish(self) -> W { @@ -336,12 +910,28 @@ fn image_dimension_str(dim: ImageDimension) -> &'static str { } } -fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { - use crate::BuiltIn; +fn builtin_str(built_in: BuiltIn) -> &'static str { match built_in { - BuiltIn::Position => Ok("SV_Position"), - BuiltIn::PointSize => Err(Error::UnsupportedBuiltIn(built_in)), - _ => Err(Error::Unimplemented(format!("builtin_str {:?}", built_in))), + BuiltIn::Position => "SV_Position", + // vertex + BuiltIn::ClipDistance => "SV_ClipDistance", + BuiltIn::CullDistance => "SV_CullDistance", + BuiltIn::InstanceIndex => "SV_InstanceID", + // based on this page https://docs.microsoft.com/en-us/windows/uwp/gaming/glsl-to-hlsl-reference#comparing-opengl-es-20-with-direct3d-11 + // No meaning unless you target Direct3D 9 + BuiltIn::PointSize => "PSIZE", + BuiltIn::VertexIndex => "SV_VertexID", + // fragment + BuiltIn::FragDepth => "SV_Depth", + BuiltIn::FrontFacing => "SV_IsFrontFace", + BuiltIn::SampleIndex => "SV_SampleIndex", + BuiltIn::SampleMask => "SV_Coverage", + // compute + BuiltIn::GlobalInvocationId => "SV_DispatchThreadID", + BuiltIn::LocalInvocationId => "SV_GroupThreadID", + BuiltIn::LocalInvocationIndex => "SV_GroupIndex", + BuiltIn::WorkGroupId => "SV_GroupID", + _ => todo!("builtin_str {:?}", built_in), } } diff --git a/src/back/mod.rs b/src/back/mod.rs index 63f84d8f03..49c7c56896 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -40,7 +40,7 @@ impl crate::Expression { /// Helper function that returns the string corresponding to the [`BinaryOperator`](crate::BinaryOperator) /// # Notes -/// Used by `glsl-out`, `msl-out`, `wgsl-out`. +/// Used by `glsl-out`, `msl-out`, `wgsl-out`, `hlsl-out`. #[allow(dead_code)] fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { use crate::BinaryOperator as Bo; @@ -68,7 +68,7 @@ fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { /// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize) /// # Notes -/// Used by `msl-out`, `wgsl-out`. +/// Used by `msl-out`, `wgsl-out`, `hlsl-out`. #[allow(dead_code)] fn vector_size_str(size: crate::VectorSize) -> &'static str { match size { diff --git a/tests/out/empty.Compute.hlsl b/tests/out/empty.Compute.hlsl deleted file mode 100644 index e79d36b9c8..0000000000 --- a/tests/out/empty.Compute.hlsl +++ /dev/null @@ -1,5 +0,0 @@ -[numthreads(1, 1, 1)] -void main() -{ - return; -} diff --git a/tests/out/empty.hlsl b/tests/out/empty.hlsl index d1d3c9cb27..a1b3590ecf 100644 --- a/tests/out/empty.hlsl +++ b/tests/out/empty.hlsl @@ -1,5 +1,5 @@ - -void main() +[numthreads(1, 1, 1)] +void comp_main() { return; } diff --git a/tests/out/empty.hlsl.config b/tests/out/empty.hlsl.config new file mode 100644 index 0000000000..f20f94588f --- /dev/null +++ b/tests/out/empty.hlsl.config @@ -0,0 +1,2 @@ +compute=cs_5_0 +compute_name=comp_main diff --git a/tests/out/quad.hlsl b/tests/out/quad.hlsl new file mode 100644 index 0000000000..c678e3e8f9 --- /dev/null +++ b/tests/out/quad.hlsl @@ -0,0 +1,34 @@ +static const float c_scale = 1.2; + +Texture2D u_texture : register(t0); +SamplerState u_sampler : register(s1); + +struct VertexOutput { + float2 uv : LOC0; + float4 position : SV_Position; +}; + +struct VertexInput { + float2 pos1 : LOC0; + float2 uv3 : LOC1; +}; + +struct FragmentInput { + float2 uv4 : LOC0; +}; + +VertexOutput vert_main(VertexInput vertexinput) +{ + const VertexOutput vertexoutput1 = { vertexinput.uv3, float4((1.2 * vertexinput.pos1), 0.0, 1.0) }; + return vertexoutput1; +} + +float4 frag_main(FragmentInput fragmentinput) : SV_Target0 +{ + float4 color = u_texture.Sample(u_sampler, fragmentinput.uv4); + if ((color.w == 0.0)) { + discard; + } + float4 premultiplied = (color.w * color); + return premultiplied; +} diff --git a/tests/out/quad.hlsl.config b/tests/out/quad.hlsl.config new file mode 100644 index 0000000000..a11240919f --- /dev/null +++ b/tests/out/quad.hlsl.config @@ -0,0 +1,4 @@ +vertex=vs_5_0 +vertex_name=vert_main +fragment=ps_5_0 +fragment_name=frag_main diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 742bb977e5..d3076391b1 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -107,9 +107,7 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) { #[cfg(feature = "hlsl-out")] { if targets.contains(Targets::HLSL) { - for ep in module.entry_points.iter() { - check_output_hlsl(module, &info, &dest, ep.stage); - } + check_output_hlsl(module, &info, &dest); } } #[cfg(feature = "wgsl-out")] @@ -221,18 +219,28 @@ fn check_output_glsl( } #[cfg(feature = "hlsl-out")] -fn check_output_hlsl( - module: &naga::Module, - info: &naga::valid::ModuleInfo, - destination: &PathBuf, - stage: naga::ShaderStage, -) { +fn check_output_hlsl(module: &naga::Module, info: &naga::valid::ModuleInfo, destination: &PathBuf) { use naga::back::hlsl; + let options = hlsl::Options::default(); + let string = hlsl::write_string(module, info, &options).unwrap(); - let string = hlsl::write_string(module, info, hlsl::ShaderModel::default()).unwrap(); + fs::write(destination.with_extension("hlsl"), string).unwrap(); - let ext = format!("{:?}.hlsl", stage); - fs::write(destination.with_extension(&ext), string).unwrap(); + // We need a config file for validation script + // This file contains an info about profiles (shader stages) contains inside generated shader + // This info will be passed to dxc + let mut config_str = String::from(""); + for ep in module.entry_points.iter() { + let (stage_str, profile, ep_name) = match ep.stage { + naga::ShaderStage::Vertex => ("vertex", "vs_5_0", &options.vertex_entry_point_name), + naga::ShaderStage::Fragment => { + ("fragment", "ps_5_0", &options.fragment_entry_point_name) + } + naga::ShaderStage::Compute => ("compute", "cs_5_0", &options.compute_entry_point_name), + }; + config_str = format!("{}{}={}\n{}_name={}\n", config_str, stage_str, profile, stage_str, ep_name); + } + fs::write(destination.with_extension("hlsl.config"), config_str).unwrap(); } #[cfg(feature = "wgsl-out")] @@ -255,7 +263,12 @@ fn convert_wgsl() { ), ( "quad", - Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::DOT | Targets::WGSL, + Targets::SPIRV + | Targets::METAL + | Targets::GLSL + | Targets::DOT + | Targets::HLSL + | Targets::WGSL, ), ( "boids",