[wgsl-out] More improvements. Enable quad snapshot testing for wgsl backend

This commit is contained in:
Gordon-F
2021-04-26 02:09:29 +03:00
committed by Dzmitry Malyshau
parent 117d729ff8
commit dd1d9fe290
4 changed files with 133 additions and 22 deletions

View File

@@ -14,9 +14,11 @@ use crate::{
proc::{NameKey, Namer},
StructMember,
};
use bit_set::BitSet;
use std::fmt::Write;
const INDENT: &str = " ";
const BAKE_PREFIX: &str = "_e";
/// Shorthand result used internally by the backend
type BackendResult = Result<(), Error>;
@@ -60,6 +62,7 @@ pub struct Writer<W> {
out: W,
names: FastHashMap<NameKey, String>,
namer: Namer,
named_expressions: BitSet,
}
impl<W: Write> Writer<W> {
@@ -68,6 +71,7 @@ impl<W: Write> Writer<W> {
out,
names: FastHashMap::default(),
namer: Namer::default(),
named_expressions: BitSet::new(),
}
}
@@ -196,6 +200,8 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "}}")?;
}
self.named_expressions.clear();
Ok(())
}
@@ -272,6 +278,16 @@ impl<W: Write> Writer<W> {
}
// Write struct member name and type
write!(self.out, "{}: ", member.name.as_ref().unwrap())?;
// Write stride attribute for array struct member
if let TypeInner::Array {
base: _,
size: _,
stride,
} = module.types[member.ty].inner
{
self.write_attributes(&[Attribute::Stride(stride)])?;
write!(self.out, " ")?;
}
self.write_type(module, member.ty)?;
write!(self.out, ";")?;
writeln!(self.out)?;
@@ -350,14 +366,10 @@ impl<W: Write> Writer<W> {
TypeInner::Scalar { kind, .. } => {
write!(self.out, "{}", scalar_kind_str(kind))?;
}
TypeInner::Array { base, size, stride } => {
TypeInner::Array { base, size, .. } => {
// More info https://gpuweb.github.io/gpuweb/wgsl/#array-types
// array<A, 3> -- Constant array
// array<A> -- Dynamic array
if stride > 0 {
self.write_attributes(&[Attribute::Stride(stride)])?;
write!(self.out, " ")?;
}
write!(self.out, "array<")?;
match size {
ArraySize::Constant(handle) => {
@@ -410,21 +422,11 @@ impl<W: Write> Writer<W> {
for handle in range.clone() {
let min_ref_count = func_ctx.expressions[handle].bake_ref_count();
if min_ref_count <= func_ctx.info[handle].ref_count {
match func_ctx.info[handle].ty {
TypeResolution::Handle(ty_handle) => {
write!(self.out, "{}", INDENT.repeat(indent))?;
self.write_type(module, ty_handle)?
}
TypeResolution::Value(ref inner) => {
//TODO:
//write!(self.out, "{}", INDENT.repeat(indent))?;
//self.write_value_type(module, inner)?
return Err(Error::Unimplemented(format!(
"Emit statement TypeResolution::Value {:?}",
inner
)));
}
}
write!(self.out, "{}", INDENT.repeat(indent))?;
self.start_baking_expr(handle, &func_ctx)?;
self.write_expr(module, handle, &func_ctx)?;
writeln!(self.out, ";")?;
self.named_expressions.insert(handle.index());
}
}
}
@@ -480,6 +482,36 @@ impl<W: Write> Writer<W> {
Ok(())
}
fn start_baking_expr(
&mut self,
handle: Handle<Expression>,
context: &FunctionCtx,
) -> BackendResult {
// Write variable name
write!(self.out, "let {}{}: ", BAKE_PREFIX, handle.index())?;
let ty = &context.info[handle].ty;
// Write variable type
match *ty {
TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => {
write!(self.out, "{}", scalar_kind_str(kind))?;
}
TypeResolution::Value(crate::TypeInner::Vector { size, kind, .. }) => {
write!(
self.out,
"vec{}<{}>",
vector_size_str(size),
scalar_kind_str(kind),
)?;
}
_ => {
return Err(Error::Unimplemented(format!("start_baking_expr {:?}", ty)));
}
}
write!(self.out, " = ")?;
Ok(())
}
/// Helper method to write expressions
///
/// # Notes
@@ -491,6 +523,12 @@ impl<W: Write> Writer<W> {
func_ctx: &FunctionCtx<'_>,
) -> BackendResult {
let expression = &func_ctx.expressions[expr];
if self.named_expressions.contains(expr.index()) {
write!(self.out, "{}{}", BAKE_PREFIX, expr.index())?;
return Ok(());
}
match *expression {
Expression::Constant(constant) => {
self.write_constant(&module.constants[constant], false)?
@@ -593,6 +631,47 @@ impl<W: Write> Writer<W> {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}", name)?;
}
Expression::As {
expr,
kind,
convert: _, //TODO:
} => {
let inner = func_ctx.info[expr].ty.inner_with(&module.types);
let op = match *inner {
TypeInner::Matrix { columns, rows, .. } => {
format!("mat{}x{}", vector_size_str(columns), vector_size_str(rows))
}
TypeInner::Vector { size, .. } => format!("vec{}", vector_size_str(size)),
_ => {
return Err(Error::Unimplemented(format!(
"write_expr expression::as {:?}",
inner
)));
}
};
let scalar = scalar_kind_str(kind);
write!(self.out, "{}<{}>(", op, scalar)?;
self.write_expr(module, expr, func_ctx)?;
write!(self.out, ")")?;
}
Expression::Splat { size, value } => {
let inner = func_ctx.info[value].ty.inner_with(&module.types);
let scalar_kind = match *inner {
crate::TypeInner::Scalar { kind, .. } => kind,
_ => {
return Err(Error::Unimplemented(format!(
"write_expr expression::splat {:?}",
inner
)));
}
};
let scalar = scalar_kind_str(scalar_kind);
let size = vector_size_str(size);
write!(self.out, "vec{}<{}>(", size, scalar)?;
self.write_expr(module, value, func_ctx)?;
write!(self.out, ")")?;
}
_ => {
return Err(Error::Unimplemented(format!("write_expr {:?}", expression)));
}

6
tests/out/access.wgsl Normal file
View File

@@ -0,0 +1,6 @@
[[stage(vertex)]]
fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4<f32> {
return vec4<f32>(vec4<i32>(array<i32,5>(1, 2, 3, 4, 5)[vi]));
}

26
tests/out/quad.wgsl Normal file
View File

@@ -0,0 +1,26 @@
let c_scale: f32 = 1.2;
[[group(0), binding(0)]] var u_texture: texture_2d<f32>;
[[group(0), binding(1)]] var u_sampler: sampler;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn main([[location(0)]] pos: vec2<f32>, [[location(1)]] uv1: vec2<f32>) -> VertexOutput {
return VertexOutput(uv1, vec4<f32>(c_scale * pos, 0.0, 1.0));
}
[[stage(fragment)]]
fn main([[location(0)]] uv2: vec2<f32>) -> [[location(0)]] vec4<f32> {
let _e4: vec4<f32> = textureSample(u_texture, u_sampler, uv2);
if (_e4[3] == 0.0) {
discard;
}
return _e4[3] * _e4;
}

View File

@@ -226,7 +226,7 @@ fn convert_wgsl() {
),
(
"quad",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::DOT,
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::DOT | Targets::WGSL,
),
("boids", Targets::SPIRV | Targets::METAL),
("skybox", Targets::SPIRV | Targets::METAL | Targets::GLSL),
@@ -242,7 +242,7 @@ fn convert_wgsl() {
"interpolate",
Targets::SPIRV | Targets::METAL | Targets::GLSL,
),
("access", Targets::SPIRV | Targets::METAL),
("access", Targets::SPIRV | Targets::METAL | Targets::WGSL),
];
for &(name, targets) in inputs.iter() {