From 02c74b50021379092a9d4567fd11aeb0879c5405 Mon Sep 17 00:00:00 2001 From: Igor Shaposhnik Date: Mon, 23 Aug 2021 05:30:22 +0300 Subject: [PATCH] [hlsl-out] Implement switch statement (#1265) * [hlsl-out] Implement switch statement * [hlsl-out] Implement switch statement * Add switch tests to control-flow snapshot --- src/back/glsl/mod.rs | 5 -- src/back/hlsl/writer.rs | 57 ++++++++++++++++++- tests/in/control-flow.wgsl | 26 +++++++++ tests/out/glsl/control-flow.main.Compute.glsl | 22 ++++++- tests/out/hlsl/control-flow.hlsl | 35 +++++++++++- tests/out/msl/control-flow.msl | 31 +++++++++- tests/out/wgsl/control-flow.wgsl | 30 +++++++++- tests/snapshots.rs | 4 +- 8 files changed, 198 insertions(+), 12 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 0558c90963..a0f94cee28 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1441,11 +1441,6 @@ impl<'a, W: Write> Writer<'a, W> { for sta in case.body.iter() { self.write_stmt(sta, ctx, indent + 2)?; } - - // Write `break;` if the block isn't fallthrough - if !case.fall_through { - writeln!(self.out, "{}break;", INDENT.repeat(indent + 2))?; - } } // Only write the default block if the block isn't empty diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 108d2b6c64..7a3f39f50b 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1204,8 +1204,61 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.temp_access_chain = chain; self.named_expressions.insert(result, res_name); } - Statement::Switch { .. } => { - return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt))) + Statement::Switch { + selector, + ref cases, + ref default, + } => { + // Start the switch + write!(self.out, "{}", INDENT.repeat(indent))?; + write!(self.out, "switch(")?; + self.write_expr(module, selector, func_ctx)?; + writeln!(self.out, ") {{")?; + + // Write all cases + let indent_str_1 = INDENT.repeat(indent + 1); + let indent_str_2 = INDENT.repeat(indent + 2); + + for case in cases { + writeln!(self.out, "{}case {}: {{", &indent_str_1, case.value)?; + + if case.fall_through { + // Generate each fallthrough case statement in a new block. This is done to + // prevent symbol collision of variables declared in these cases statements. + writeln!(self.out, "{}/* fallthrough */", &indent_str_2)?; + writeln!(self.out, "{}{{", &indent_str_2)?; + } + for sta in case.body.iter() { + self.write_stmt( + module, + sta, + func_ctx, + indent + 2 + usize::from(case.fall_through), + )?; + } + + if case.fall_through { + writeln!(self.out, "{}}}", &indent_str_2)?; + } else { + writeln!(self.out, "{}break;", &indent_str_2)?; + } + + writeln!(self.out, "{}}}", &indent_str_1)?; + } + + // Only write the default block if the block isn't empty + // Writing default without a block is valid but it's more readable this way + if !default.is_empty() { + writeln!(self.out, "{}default: {{", &indent_str_1)?; + + for sta in default { + self.write_stmt(module, sta, func_ctx, indent + 2)?; + } + + writeln!(self.out, "{}}}", &indent_str_1)?; + } + + writeln!(self.out, "{}}}", INDENT.repeat(indent))? } } diff --git a/tests/in/control-flow.wgsl b/tests/in/control-flow.wgsl index 893546e430..6cb0f3619b 100644 --- a/tests/in/control-flow.wgsl +++ b/tests/in/control-flow.wgsl @@ -3,4 +3,30 @@ fn main([[builtin(global_invocation_id)]] global_id: vec3) { //TODO: execution-only barrier? storageBarrier(); workgroupBarrier(); + + var pos: i32; + // switch without cases + switch (1) { + default: { + pos = 1; + } + } + + switch (pos) { + case 1: { + pos = 0; + break; + } + case 2: { + pos = 1; + } + case 3: { + pos = 2; + fallthrough; + } + case 4: {} + default: { + pos = 3; + } + } } diff --git a/tests/out/glsl/control-flow.main.Compute.glsl b/tests/out/glsl/control-flow.main.Compute.glsl index d2bc4571f8..fd12e43461 100644 --- a/tests/out/glsl/control-flow.main.Compute.glsl +++ b/tests/out/glsl/control-flow.main.Compute.glsl @@ -8,8 +8,28 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; void main() { uvec3 global_id = gl_GlobalInvocationID; + int pos = 0; groupMemoryBarrier(); groupMemoryBarrier(); - return; + switch(1) { + default: + pos = 1; + } + int _e4 = pos; + switch(_e4) { + case 1: + pos = 0; + break; + case 2: + pos = 1; + return; + case 3: + pos = 2; + case 4: + return; + default: + pos = 3; + return; + } } diff --git a/tests/out/hlsl/control-flow.hlsl b/tests/out/hlsl/control-flow.hlsl index 2d6810aab4..b4e5192415 100644 --- a/tests/out/hlsl/control-flow.hlsl +++ b/tests/out/hlsl/control-flow.hlsl @@ -6,7 +6,40 @@ struct ComputeInput_main { [numthreads(1, 1, 1)] void main(ComputeInput_main computeinput_main) { + int pos = (int)0; + DeviceMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync(); - return; + switch(1) { + default: { + pos = 1; + } + } + int _expr4 = pos; + switch(_expr4) { + case 1: { + pos = 0; + break; + break; + } + case 2: { + pos = 1; + return; + break; + } + case 3: { + /* fallthrough */ + { + pos = 2; + } + } + case 4: { + return; + break; + } + default: { + pos = 3; + return; + } + } } diff --git a/tests/out/msl/control-flow.msl b/tests/out/msl/control-flow.msl index ef3b1eee2f..bd9b8d27d9 100644 --- a/tests/out/msl/control-flow.msl +++ b/tests/out/msl/control-flow.msl @@ -8,7 +8,36 @@ struct main1Input { kernel void main1( metal::uint3 global_id [[thread_position_in_grid]] ) { + int pos; metal::threadgroup_barrier(metal::mem_flags::mem_device); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - return; + switch(1) { + default: { + pos = 1; + } + } + int _e4 = pos; + switch(_e4) { + case 1: { + pos = 0; + break; + break; + } + case 2: { + pos = 1; + return; + break; + } + case 3: { + pos = 2; + } + case 4: { + return; + break; + } + default: { + pos = 3; + return; + } + } } diff --git a/tests/out/wgsl/control-flow.wgsl b/tests/out/wgsl/control-flow.wgsl index 31b2fa2cb9..3eb55cfd80 100644 --- a/tests/out/wgsl/control-flow.wgsl +++ b/tests/out/wgsl/control-flow.wgsl @@ -1,6 +1,34 @@ [[stage(compute), workgroup_size(1, 1, 1)]] fn main([[builtin(global_invocation_id)]] global_id: vec3) { + var pos: i32; + storageBarrier(); workgroupBarrier(); - return; + switch(1) { + default: { + pos = 1; + } + } + let _e4: i32 = pos; + switch(_e4) { + case 1: { + pos = 0; + break; + } + case 2: { + pos = 1; + return; + } + case 3: { + pos = 2; + fallthrough; + } + case 4: { + return; + } + default: { + pos = 3; + return; + } + } } diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 191ef8b622..3bb1f1798e 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -425,7 +425,9 @@ fn convert_wgsl() { ), ( "control-flow", - Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + // TODO: SPIRV https://github.com/gfx-rs/naga/issues/1017 + //Targets::SPIRV | + Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ( "standard",