diff --git a/src/back/wgsl/mod.rs b/src/back/wgsl/mod.rs index 1ca762af9c..6408b51251 100644 --- a/src/back/wgsl/mod.rs +++ b/src/back/wgsl/mod.rs @@ -13,6 +13,8 @@ pub enum Error { Custom(String), #[error("{0}")] Unimplemented(String), // TODO: Error used only during development + #[error("Unsupported math function: {0:?}")] + UnsupportedMathFunction(crate::MathFunction), } pub fn write_string( diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 5f634badcd..4d7701d652 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -61,6 +61,7 @@ pub struct Writer { names: FastHashMap, namer: Namer, named_expressions: BitSet, + ep_results: Vec<(ShaderStage, Handle)>, } impl Writer { @@ -70,6 +71,7 @@ impl Writer { names: FastHashMap::default(), namer: Namer::default(), named_expressions: BitSet::new(), + ep_results: vec![], } } @@ -82,6 +84,13 @@ impl Writer { pub fn write(&mut self, module: &Module, info: &ModuleInfo) -> BackendResult { self.reset(module); + // Save all ep result types + for (_, ep) in module.entry_points.iter().enumerate() { + if let Some(ref result) = ep.function.result { + self.ep_results.push((ep.stage, result.ty)); + } + } + // Write all structs for (handle, ty) in module.types.iter() { if let TypeInner::Struct { @@ -173,6 +182,29 @@ impl Writer { Ok(()) } + /// Helper method used to write stuct name + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_struct_name(&mut self, module: &Module, handle: Handle) -> BackendResult { + if module.types[handle].name.is_none() { + if let Some(&(stage, _)) = self.ep_results.iter().find(|&&(_, ty)| ty == handle) { + let name = match stage { + ShaderStage::Compute => "ComputeOutput", + ShaderStage::Fragment => "FragmentOutput", + ShaderStage::Vertex => "VertexOutput", + }; + + write!(self.out, "{}", name)?; + return Ok(()); + } + } + + write!(self.out, "{}", self.names[&NameKey::Type(handle)])?; + + Ok(()) + } + /// Helper method used to write structs /// https://gpuweb.github.io/gpuweb/wgsl/#functions /// @@ -222,14 +254,11 @@ impl Writer { // Write function return type if let Some(ref result) = func.result { + write!(self.out, " -> ")?; if let Some(ref binding) = result.binding { - write!(self.out, " -> ")?; self.write_attributes(&map_binding_to_attribute(binding), true)?; - self.write_type(module, result.ty)?; - } else { - let struct_name = &self.names[&NameKey::Type(result.ty)].clone(); - write!(self.out, " -> {}", struct_name)?; } + self.write_type(module, result.ty)?; } write!(self.out, " {{")?; @@ -383,8 +412,10 @@ impl Writer { self.write_attributes(&[Attribute::Block], false)?; writeln!(self.out)?; } - let name = &self.names[&NameKey::Type(handle)].clone(); - write!(self.out, "struct {} {{", name)?; + + write!(self.out, "struct ")?; + self.write_struct_name(module, handle)?; + write!(self.out, " {{")?; writeln!(self.out)?; for (index, member) in members.iter().enumerate() { // Skip struct member with unsupported built in @@ -431,12 +462,7 @@ impl Writer { fn write_type(&mut self, module: &Module, ty: Handle) -> BackendResult { let inner = &module.types[ty].inner; match *inner { - TypeInner::Struct { .. } => { - // Get the struct name - let name = &self.names[&NameKey::Type(ty)]; - write!(self.out, "{}", name)?; - return Ok(()); - } + TypeInner::Struct { .. } => self.write_struct_name(module, ty)?, ref other => self.write_value_type(module, other)?, } @@ -1010,13 +1036,60 @@ impl Writer { use crate::MathFunction as Mf; let fun_name = match fun { + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + // trigonometry + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Atan2 => "atan2", + // decomposition + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "round", + Mf::Fract => "fract", + Mf::Trunc => "trunc", + Mf::Modf => "modf", + Mf::Frexp => "frexp", + Mf::Ldexp => "ldexp", + // exponent + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + // geometry + Mf::Dot => "dot", + Mf::Outer => "outerProduct", + Mf::Cross => "cross", + Mf::Distance => "distance", Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceForward", + Mf::Reflect => "reflect", + // computational + Mf::Sign => "sign", + Mf::Fma => "fma", Mf::Mix => "mix", + Mf::Step => "step", + Mf::SmoothStep => "smoothStep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "inverseSqrt", + Mf::Transpose => "transpose", + Mf::Determinant => "determinant", + // bits + Mf::CountOneBits => "countOneBits", + Mf::ReverseBits => "reverseBits", _ => { - return Err(Error::Unimplemented(format!( - "write_expr Math func {:?}", - fun - ))); + return Err(Error::UnsupportedMathFunction(fun)); } }; diff --git a/tests/out/quad-vert.wgsl b/tests/out/quad-vert.wgsl index 66a173b4cf..579fe73e58 100644 --- a/tests/out/quad-vert.wgsl +++ b/tests/out/quad-vert.wgsl @@ -3,7 +3,7 @@ struct gl_PerVertex { [[builtin(position)]] gl_Position: vec4; }; -struct type10 { +struct VertexOutput { [[location(0), interpolate(perspective)]] member: vec2; [[builtin(position)]] gl_Position1: vec4; }; @@ -21,9 +21,9 @@ fn main() { } [[stage(vertex)]] -fn main1([[location(1)]] a_uv1: vec2, [[location(0)]] a_pos1: vec2) -> type10 { +fn main1([[location(1)]] a_uv1: vec2, [[location(0)]] a_pos1: vec2) -> VertexOutput { a_uv = a_uv1; a_pos = a_pos1; main(); - return type10(v_uv, perVertexStruct.gl_Position); + return VertexOutput(v_uv, perVertexStruct.gl_Position); }