From dd1d9fe29096d9debcde5b6c0ecad27c01cc552b Mon Sep 17 00:00:00 2001 From: Gordon-F Date: Mon, 26 Apr 2021 02:09:29 +0300 Subject: [PATCH] [wgsl-out] More improvements. Enable quad snapshot testing for wgsl backend --- src/back/wgsl/writer.rs | 119 +++++++++++++++++++++++++++++++++------- tests/out/access.wgsl | 6 ++ tests/out/quad.wgsl | 26 +++++++++ tests/snapshots.rs | 4 +- 4 files changed, 133 insertions(+), 22 deletions(-) create mode 100644 tests/out/access.wgsl create mode 100644 tests/out/quad.wgsl diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 991e8f4834..f3cbeb80ba 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -14,9 +14,11 @@ use crate::{ proc::{NameKey, Namer}, StructMember, }; +use bit_set::BitSet; use std::fmt::Write; const INDENT: &str = " "; +const BAKE_PREFIX: &str = "_e"; /// Shorthand result used internally by the backend type BackendResult = Result<(), Error>; @@ -60,6 +62,7 @@ pub struct Writer { out: W, names: FastHashMap, namer: Namer, + named_expressions: BitSet, } impl Writer { @@ -68,6 +71,7 @@ impl Writer { out, names: FastHashMap::default(), namer: Namer::default(), + named_expressions: BitSet::new(), } } @@ -196,6 +200,8 @@ impl Writer { writeln!(self.out, "}}")?; } + self.named_expressions.clear(); + Ok(()) } @@ -272,6 +278,16 @@ impl Writer { } // Write struct member name and type write!(self.out, "{}: ", member.name.as_ref().unwrap())?; + // Write stride attribute for array struct member + if let TypeInner::Array { + base: _, + size: _, + stride, + } = module.types[member.ty].inner + { + self.write_attributes(&[Attribute::Stride(stride)])?; + write!(self.out, " ")?; + } self.write_type(module, member.ty)?; write!(self.out, ";")?; writeln!(self.out)?; @@ -350,14 +366,10 @@ impl Writer { TypeInner::Scalar { kind, .. } => { write!(self.out, "{}", scalar_kind_str(kind))?; } - TypeInner::Array { base, size, stride } => { + TypeInner::Array { base, size, .. } => { // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types // array -- Constant array // array -- Dynamic array - if stride > 0 { - self.write_attributes(&[Attribute::Stride(stride)])?; - write!(self.out, " ")?; - } write!(self.out, "array<")?; match size { ArraySize::Constant(handle) => { @@ -410,21 +422,11 @@ impl Writer { for handle in range.clone() { let min_ref_count = func_ctx.expressions[handle].bake_ref_count(); if min_ref_count <= func_ctx.info[handle].ref_count { - match func_ctx.info[handle].ty { - TypeResolution::Handle(ty_handle) => { - write!(self.out, "{}", INDENT.repeat(indent))?; - self.write_type(module, ty_handle)? - } - TypeResolution::Value(ref inner) => { - //TODO: - //write!(self.out, "{}", INDENT.repeat(indent))?; - //self.write_value_type(module, inner)? - return Err(Error::Unimplemented(format!( - "Emit statement TypeResolution::Value {:?}", - inner - ))); - } - } + write!(self.out, "{}", INDENT.repeat(indent))?; + self.start_baking_expr(handle, &func_ctx)?; + self.write_expr(module, handle, &func_ctx)?; + writeln!(self.out, ";")?; + self.named_expressions.insert(handle.index()); } } } @@ -480,6 +482,36 @@ impl Writer { Ok(()) } + fn start_baking_expr( + &mut self, + handle: Handle, + context: &FunctionCtx, + ) -> BackendResult { + // Write variable name + write!(self.out, "let {}{}: ", BAKE_PREFIX, handle.index())?; + let ty = &context.info[handle].ty; + // Write variable type + match *ty { + TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => { + write!(self.out, "{}", scalar_kind_str(kind))?; + } + TypeResolution::Value(crate::TypeInner::Vector { size, kind, .. }) => { + write!( + self.out, + "vec{}<{}>", + vector_size_str(size), + scalar_kind_str(kind), + )?; + } + _ => { + return Err(Error::Unimplemented(format!("start_baking_expr {:?}", ty))); + } + } + + write!(self.out, " = ")?; + Ok(()) + } + /// Helper method to write expressions /// /// # Notes @@ -491,6 +523,12 @@ impl Writer { func_ctx: &FunctionCtx<'_>, ) -> BackendResult { let expression = &func_ctx.expressions[expr]; + + if self.named_expressions.contains(expr.index()) { + write!(self.out, "{}{}", BAKE_PREFIX, expr.index())?; + return Ok(()); + } + match *expression { Expression::Constant(constant) => { self.write_constant(&module.constants[constant], false)? @@ -593,6 +631,47 @@ impl Writer { let name = &self.names[&NameKey::GlobalVariable(handle)]; write!(self.out, "{}", name)?; } + Expression::As { + expr, + kind, + convert: _, //TODO: + } => { + let inner = func_ctx.info[expr].ty.inner_with(&module.types); + let op = match *inner { + TypeInner::Matrix { columns, rows, .. } => { + format!("mat{}x{}", vector_size_str(columns), vector_size_str(rows)) + } + TypeInner::Vector { size, .. } => format!("vec{}", vector_size_str(size)), + _ => { + return Err(Error::Unimplemented(format!( + "write_expr expression::as {:?}", + inner + ))); + } + }; + let scalar = scalar_kind_str(kind); + write!(self.out, "{}<{}>(", op, scalar)?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Splat { size, value } => { + let inner = func_ctx.info[value].ty.inner_with(&module.types); + let scalar_kind = match *inner { + crate::TypeInner::Scalar { kind, .. } => kind, + _ => { + return Err(Error::Unimplemented(format!( + "write_expr expression::splat {:?}", + inner + ))); + } + }; + let scalar = scalar_kind_str(scalar_kind); + let size = vector_size_str(size); + + write!(self.out, "vec{}<{}>(", size, scalar)?; + self.write_expr(module, value, func_ctx)?; + write!(self.out, ")")?; + } _ => { return Err(Error::Unimplemented(format!("write_expr {:?}", expression))); } diff --git a/tests/out/access.wgsl b/tests/out/access.wgsl new file mode 100644 index 0000000000..4b1cbc85ec --- /dev/null +++ b/tests/out/access.wgsl @@ -0,0 +1,6 @@ +[[stage(vertex)]] +fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4 { + return vec4(vec4(array(1, 2, 3, 4, 5)[vi])); +} + + diff --git a/tests/out/quad.wgsl b/tests/out/quad.wgsl new file mode 100644 index 0000000000..ba45b48c26 --- /dev/null +++ b/tests/out/quad.wgsl @@ -0,0 +1,26 @@ +let c_scale: f32 = 1.2; + +[[group(0), binding(0)]] var u_texture: texture_2d; + +[[group(0), binding(1)]] var u_sampler: sampler; + +struct VertexOutput { + [[location(0)]] uv: vec2; + [[builtin(position)]] position: vec4; +}; + +[[stage(vertex)]] +fn main([[location(0)]] pos: vec2, [[location(1)]] uv1: vec2) -> VertexOutput { + return VertexOutput(uv1, vec4(c_scale * pos, 0.0, 1.0)); +} + +[[stage(fragment)]] +fn main([[location(0)]] uv2: vec2) -> [[location(0)]] vec4 { + let _e4: vec4 = textureSample(u_texture, u_sampler, uv2); + if (_e4[3] == 0.0) { + discard; + } + return _e4[3] * _e4; +} + + diff --git a/tests/snapshots.rs b/tests/snapshots.rs index ed46091c0b..b82e14b71c 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -226,7 +226,7 @@ fn convert_wgsl() { ), ( "quad", - Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::DOT, + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::DOT | Targets::WGSL, ), ("boids", Targets::SPIRV | Targets::METAL), ("skybox", Targets::SPIRV | Targets::METAL | Targets::GLSL), @@ -242,7 +242,7 @@ fn convert_wgsl() { "interpolate", Targets::SPIRV | Targets::METAL | Targets::GLSL, ), - ("access", Targets::SPIRV | Targets::METAL), + ("access", Targets::SPIRV | Targets::METAL | Targets::WGSL), ]; for &(name, targets) in inputs.iter() {