From 91ee407c87638ed81844ac2f71e5cd0cbf3dbb61 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Sat, 14 May 2022 16:46:08 +0200 Subject: [PATCH] [hlsl-out] fix fallthrough in switch statements --- src/back/hlsl/writer.rs | 46 +++-- tests/in/control-flow.wgsl | 7 +- tests/out/glsl/control-flow.main.Compute.glsl | 9 +- tests/out/hlsl/control-flow.hlsl | 24 ++- tests/out/msl/control-flow.msl | 8 +- tests/out/spv/control-flow.spvasm | 192 +++++++++--------- tests/out/wgsl/control-flow.wgsl | 8 +- 7 files changed, 163 insertions(+), 131 deletions(-) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index e713c2f5dd..16d21e0798 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1651,7 +1651,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { let indent_level_1 = level.next(); let indent_level_2 = indent_level_1.next(); - for case in cases { + for (i, case) in cases.iter().enumerate() { match case.value { crate::SwitchValue::Integer(value) => writeln!( self.out, @@ -1663,25 +1663,35 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } + // FXC doesn't support fallthrough so we duplicate the body of the following case blocks if case.fall_through { - // Generate each fallthrough case statement in a new block. This is done to - // prevent symbol collision of variables declared in these cases statements. - writeln!(self.out, "{}/* fallthrough */", indent_level_2)?; - writeln!(self.out, "{}{{", indent_level_2)?; - } - for sta in case.body.iter() { - self.write_stmt( - module, - sta, - func_ctx, - back::Level(indent_level_2.0 + usize::from(case.fall_through)), - )?; - } + let curr_len = i + 1; + let end_case_idx = curr_len + + cases + .iter() + .skip(curr_len) + .position(|case| !case.fall_through) + .unwrap(); + let indent_level_3 = indent_level_2.next(); + for case in &cases[i..=end_case_idx] { + writeln!(self.out, "{}{{", indent_level_2)?; + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, indent_level_3)?; + } + writeln!(self.out, "{}}}", indent_level_2)?; + } - if case.fall_through { - writeln!(self.out, "{}}}", indent_level_2)?; - } else if case.body.last().map_or(true, |s| !s.is_terminator()) { - writeln!(self.out, "{}break;", indent_level_2)?; + let last_case = &cases[end_case_idx]; + if last_case.body.last().map_or(true, |s| !s.is_terminator()) { + writeln!(self.out, "{}break;", indent_level_2)?; + } + } else { + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, indent_level_2)?; + } + if case.body.last().map_or(true, |s| !s.is_terminator()) { + writeln!(self.out, "{}break;", indent_level_2)?; + } } writeln!(self.out, "{}}}", indent_level_1)?; diff --git a/tests/in/control-flow.wgsl b/tests/in/control-flow.wgsl index f059eeeb8f..787742d71d 100644 --- a/tests/in/control-flow.wgsl +++ b/tests/in/control-flow.wgsl @@ -26,9 +26,12 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { pos = 2; fallthrough; } - case 4: {} - default: { + case 4: { pos = 3; + fallthrough; + } + default: { + pos = 4; } } diff --git a/tests/out/glsl/control-flow.main.Compute.glsl b/tests/out/glsl/control-flow.main.Compute.glsl index a5e1034133..5446e36098 100644 --- a/tests/out/glsl/control-flow.main.Compute.glsl +++ b/tests/out/glsl/control-flow.main.Compute.glsl @@ -59,9 +59,10 @@ void main() { pos = 2; /* fallthrough */ case 4: - break; - default: pos = 3; + /* fallthrough */ + default: + pos = 4; break; } switch(0u) { @@ -70,8 +71,8 @@ void main() { default: break; } - int _e10 = pos; - switch(_e10) { + int _e11 = pos; + switch(_e11) { case 1: pos = 0; break; diff --git a/tests/out/hlsl/control-flow.hlsl b/tests/out/hlsl/control-flow.hlsl index e65dbc75f2..4b879ef20c 100644 --- a/tests/out/hlsl/control-flow.hlsl +++ b/tests/out/hlsl/control-flow.hlsl @@ -60,16 +60,28 @@ void main(uint3 global_id : SV_DispatchThreadID) break; } case 3: { - /* fallthrough */ { pos = 2; } + { + pos = 3; + } + { + pos = 4; + } + break; } case 4: { + { + pos = 3; + } + { + pos = 4; + } break; } default: { - pos = 3; + pos = 4; break; } } @@ -81,8 +93,8 @@ void main(uint3 global_id : SV_DispatchThreadID) break; } } - int _expr10 = pos; - switch(_expr10) { + int _expr11 = pos; + switch(_expr11) { case 1: { pos = 0; break; @@ -92,10 +104,12 @@ void main(uint3 global_id : SV_DispatchThreadID) return; } case 3: { - /* fallthrough */ { pos = 2; } + { + return; + } } case 4: { return; diff --git a/tests/out/msl/control-flow.msl b/tests/out/msl/control-flow.msl index e35d24d5fc..4a0b805663 100644 --- a/tests/out/msl/control-flow.msl +++ b/tests/out/msl/control-flow.msl @@ -72,10 +72,10 @@ kernel void main_( pos = 2; } case 4: { - break; + pos = 3; } default: { - pos = 3; + pos = 4; break; } } @@ -87,8 +87,8 @@ kernel void main_( break; } } - int _e10 = pos; - switch(_e10) { + int _e11 = pos; + switch(_e11) { case 1: { pos = 0; break; diff --git a/tests/out/spv/control-flow.spvasm b/tests/out/spv/control-flow.spvasm index 93da801035..dfd1b145a2 100644 --- a/tests/out/spv/control-flow.spvasm +++ b/tests/out/spv/control-flow.spvasm @@ -1,136 +1,138 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 68 +; Bound: 69 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %43 "main" %40 -OpExecutionMode %43 LocalSize 1 1 1 -OpDecorate %40 BuiltIn GlobalInvocationId +OpEntryPoint GLCompute %44 "main" %41 +OpExecutionMode %44 LocalSize 1 1 1 +OpDecorate %41 BuiltIn GlobalInvocationId %2 = OpTypeVoid %4 = OpTypeInt 32 1 %3 = OpConstant %4 1 %5 = OpConstant %4 0 %6 = OpConstant %4 2 %7 = OpConstant %4 3 -%9 = OpTypeInt 32 0 -%8 = OpConstant %9 0 -%10 = OpTypeVector %9 3 -%14 = OpTypeFunction %2 %4 -%20 = OpTypeFunction %2 -%37 = OpTypePointer Function %4 -%38 = OpConstantNull %4 -%41 = OpTypePointer Input %10 -%40 = OpVariable %41 Input -%45 = OpConstant %9 2 -%46 = OpConstant %9 1 -%47 = OpConstant %9 72 -%48 = OpConstant %9 264 -%13 = OpFunction %2 None %14 -%12 = OpFunctionParameter %4 -%11 = OpLabel -OpBranch %15 -%15 = OpLabel -OpSelectionMerge %16 None -OpSwitch %12 %17 -%17 = OpLabel +%8 = OpConstant %4 4 +%10 = OpTypeInt 32 0 +%9 = OpConstant %10 0 +%11 = OpTypeVector %10 3 +%15 = OpTypeFunction %2 %4 +%21 = OpTypeFunction %2 +%38 = OpTypePointer Function %4 +%39 = OpConstantNull %4 +%42 = OpTypePointer Input %11 +%41 = OpVariable %42 Input +%46 = OpConstant %10 2 +%47 = OpConstant %10 1 +%48 = OpConstant %10 72 +%49 = OpConstant %10 264 +%14 = OpFunction %2 None %15 +%13 = OpFunctionParameter %4 +%12 = OpLabel OpBranch %16 %16 = OpLabel +OpSelectionMerge %17 None +OpSwitch %13 %18 +%18 = OpLabel +OpBranch %17 +%17 = OpLabel OpReturn OpFunctionEnd -%19 = OpFunction %2 None %20 -%18 = OpLabel -OpBranch %21 -%21 = OpLabel -OpSelectionMerge %22 None -OpSwitch %5 %23 0 %24 -%24 = OpLabel -OpBranch %22 -%23 = OpLabel +%20 = OpFunction %2 None %21 +%19 = OpLabel OpBranch %22 %22 = OpLabel +OpSelectionMerge %23 None +OpSwitch %5 %24 0 %25 +%25 = OpLabel +OpBranch %23 +%24 = OpLabel +OpBranch %23 +%23 = OpLabel OpReturn OpFunctionEnd -%27 = OpFunction %2 None %14 -%26 = OpFunctionParameter %4 -%25 = OpLabel -OpBranch %28 -%28 = OpLabel +%28 = OpFunction %2 None %15 +%27 = OpFunctionParameter %4 +%26 = OpLabel OpBranch %29 %29 = OpLabel -OpLoopMerge %30 %32 None -OpBranch %31 -%31 = OpLabel -OpSelectionMerge %33 None -OpSwitch %26 %34 1 %35 -%35 = OpLabel +OpBranch %30 +%30 = OpLabel +OpLoopMerge %31 %33 None OpBranch %32 +%32 = OpLabel +OpSelectionMerge %34 None +OpSwitch %27 %35 1 %36 +%36 = OpLabel +OpBranch %33 +%35 = OpLabel +OpBranch %34 %34 = OpLabel OpBranch %33 %33 = OpLabel -OpBranch %32 -%32 = OpLabel -OpBranch %29 -%30 = OpLabel +OpBranch %30 +%31 = OpLabel OpReturn OpFunctionEnd -%43 = OpFunction %2 None %20 -%39 = OpLabel -%36 = OpVariable %37 Function %38 -%42 = OpLoad %10 %40 -OpBranch %44 -%44 = OpLabel -OpControlBarrier %45 %46 %47 -OpControlBarrier %45 %45 %48 -OpSelectionMerge %49 None -OpSwitch %3 %50 +%44 = OpFunction %2 None %21 +%40 = OpLabel +%37 = OpVariable %38 Function %39 +%43 = OpLoad %11 %41 +OpBranch %45 +%45 = OpLabel +OpControlBarrier %46 %47 %48 +OpControlBarrier %46 %46 %49 +OpSelectionMerge %50 None +OpSwitch %3 %51 +%51 = OpLabel +OpStore %37 %3 +OpBranch %50 %50 = OpLabel -OpStore %36 %3 -OpBranch %49 -%49 = OpLabel -%51 = OpLoad %4 %36 -OpSelectionMerge %52 None -OpSwitch %51 %53 1 %54 2 %55 3 %56 4 %57 -%54 = OpLabel -OpStore %36 %5 -OpBranch %52 +%52 = OpLoad %4 %37 +OpSelectionMerge %53 None +OpSwitch %52 %54 1 %55 2 %56 3 %57 4 %58 %55 = OpLabel -OpStore %36 %3 -OpBranch %52 +OpStore %37 %5 +OpBranch %53 %56 = OpLabel -OpStore %36 %6 -OpBranch %57 +OpStore %37 %3 +OpBranch %53 %57 = OpLabel -OpBranch %52 -%53 = OpLabel -OpStore %36 %7 -OpBranch %52 -%52 = OpLabel -OpSelectionMerge %58 None -OpSwitch %8 %59 0 %60 -%60 = OpLabel -OpBranch %58 -%59 = OpLabel +OpStore %37 %6 OpBranch %58 %58 = OpLabel -%61 = OpLoad %4 %36 -OpSelectionMerge %62 None -OpSwitch %61 %63 1 %64 2 %65 3 %66 4 %67 -%64 = OpLabel -OpStore %36 %5 -OpBranch %62 +OpStore %37 %7 +OpBranch %54 +%54 = OpLabel +OpStore %37 %8 +OpBranch %53 +%53 = OpLabel +OpSelectionMerge %59 None +OpSwitch %9 %60 0 %61 +%61 = OpLabel +OpBranch %59 +%60 = OpLabel +OpBranch %59 +%59 = OpLabel +%62 = OpLoad %4 %37 +OpSelectionMerge %63 None +OpSwitch %62 %64 1 %65 2 %66 3 %67 4 %68 %65 = OpLabel -OpStore %36 %3 -OpReturn +OpStore %37 %5 +OpBranch %63 %66 = OpLabel -OpStore %36 %6 -OpBranch %67 +OpStore %37 %3 +OpReturn %67 = OpLabel +OpStore %37 %6 +OpBranch %68 +%68 = OpLabel +OpReturn +%64 = OpLabel +OpStore %37 %7 OpReturn %63 = OpLabel -OpStore %36 %7 -OpReturn -%62 = OpLabel OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/control-flow.wgsl b/tests/out/wgsl/control-flow.wgsl index 55ad20ce2e..ce3911f2b1 100644 --- a/tests/out/wgsl/control-flow.wgsl +++ b/tests/out/wgsl/control-flow.wgsl @@ -55,9 +55,11 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { fallthrough; } case 4: { + pos = 3; + fallthrough; } default: { - pos = 3; + pos = 4; } } switch 0u { @@ -66,8 +68,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { default: { } } - let _e10 = pos; - switch _e10 { + let _e11 = pos; + switch _e11 { case 1: { pos = 0; break;