[hlsl-out] Implement switch statement (#1265)

* [hlsl-out] Implement switch statement

* [hlsl-out] Implement switch statement

* Add switch tests to control-flow snapshot
This commit is contained in:
Igor Shaposhnik
2021-08-23 05:30:22 +03:00
committed by Dzmitry Malyshau
parent 7d88637bbf
commit 02c74b5002
8 changed files with 198 additions and 12 deletions

View File

@@ -1441,11 +1441,6 @@ impl<'a, W: Write> Writer<'a, W> {
for sta in case.body.iter() {
self.write_stmt(sta, ctx, indent + 2)?;
}
// Write `break;` if the block isn't fallthrough
if !case.fall_through {
writeln!(self.out, "{}break;", INDENT.repeat(indent + 2))?;
}
}
// Only write the default block if the block isn't empty

View File

@@ -1204,8 +1204,61 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.temp_access_chain = chain;
self.named_expressions.insert(result, res_name);
}
Statement::Switch { .. } => {
return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt)))
Statement::Switch {
selector,
ref cases,
ref default,
} => {
// Start the switch
write!(self.out, "{}", INDENT.repeat(indent))?;
write!(self.out, "switch(")?;
self.write_expr(module, selector, func_ctx)?;
writeln!(self.out, ") {{")?;
// Write all cases
let indent_str_1 = INDENT.repeat(indent + 1);
let indent_str_2 = INDENT.repeat(indent + 2);
for case in cases {
writeln!(self.out, "{}case {}: {{", &indent_str_1, case.value)?;
if case.fall_through {
// Generate each fallthrough case statement in a new block. This is done to
// prevent symbol collision of variables declared in these cases statements.
writeln!(self.out, "{}/* fallthrough */", &indent_str_2)?;
writeln!(self.out, "{}{{", &indent_str_2)?;
}
for sta in case.body.iter() {
self.write_stmt(
module,
sta,
func_ctx,
indent + 2 + usize::from(case.fall_through),
)?;
}
if case.fall_through {
writeln!(self.out, "{}}}", &indent_str_2)?;
} else {
writeln!(self.out, "{}break;", &indent_str_2)?;
}
writeln!(self.out, "{}}}", &indent_str_1)?;
}
// Only write the default block if the block isn't empty
// Writing default without a block is valid but it's more readable this way
if !default.is_empty() {
writeln!(self.out, "{}default: {{", &indent_str_1)?;
for sta in default {
self.write_stmt(module, sta, func_ctx, indent + 2)?;
}
writeln!(self.out, "{}}}", &indent_str_1)?;
}
writeln!(self.out, "{}}}", INDENT.repeat(indent))?
}
}

View File

@@ -3,4 +3,30 @@ fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
//TODO: execution-only barrier?
storageBarrier();
workgroupBarrier();
var pos: i32;
// switch without cases
switch (1) {
default: {
pos = 1;
}
}
switch (pos) {
case 1: {
pos = 0;
break;
}
case 2: {
pos = 1;
}
case 3: {
pos = 2;
fallthrough;
}
case 4: {}
default: {
pos = 3;
}
}
}

View File

@@ -8,8 +8,28 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
uvec3 global_id = gl_GlobalInvocationID;
int pos = 0;
groupMemoryBarrier();
groupMemoryBarrier();
return;
switch(1) {
default:
pos = 1;
}
int _e4 = pos;
switch(_e4) {
case 1:
pos = 0;
break;
case 2:
pos = 1;
return;
case 3:
pos = 2;
case 4:
return;
default:
pos = 3;
return;
}
}

View File

@@ -6,7 +6,40 @@ struct ComputeInput_main {
[numthreads(1, 1, 1)]
void main(ComputeInput_main computeinput_main)
{
int pos = (int)0;
DeviceMemoryBarrierWithGroupSync();
GroupMemoryBarrierWithGroupSync();
return;
switch(1) {
default: {
pos = 1;
}
}
int _expr4 = pos;
switch(_expr4) {
case 1: {
pos = 0;
break;
break;
}
case 2: {
pos = 1;
return;
break;
}
case 3: {
/* fallthrough */
{
pos = 2;
}
}
case 4: {
return;
break;
}
default: {
pos = 3;
return;
}
}
}

View File

@@ -8,7 +8,36 @@ struct main1Input {
kernel void main1(
metal::uint3 global_id [[thread_position_in_grid]]
) {
int pos;
metal::threadgroup_barrier(metal::mem_flags::mem_device);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
return;
switch(1) {
default: {
pos = 1;
}
}
int _e4 = pos;
switch(_e4) {
case 1: {
pos = 0;
break;
break;
}
case 2: {
pos = 1;
return;
break;
}
case 3: {
pos = 2;
}
case 4: {
return;
break;
}
default: {
pos = 3;
return;
}
}
}

View File

@@ -1,6 +1,34 @@
[[stage(compute), workgroup_size(1, 1, 1)]]
fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
var pos: i32;
storageBarrier();
workgroupBarrier();
return;
switch(1) {
default: {
pos = 1;
}
}
let _e4: i32 = pos;
switch(_e4) {
case 1: {
pos = 0;
break;
}
case 2: {
pos = 1;
return;
}
case 3: {
pos = 2;
fallthrough;
}
case 4: {
return;
}
default: {
pos = 3;
return;
}
}
}

View File

@@ -425,7 +425,9 @@ fn convert_wgsl() {
),
(
"control-flow",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
// TODO: SPIRV https://github.com/gfx-rs/naga/issues/1017
//Targets::SPIRV |
Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"standard",