diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 817fa78b0a..34dfba76a1 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -880,25 +880,47 @@ impl Writer { }; let l2 = level.next(); - if !cases.is_empty() { - for case in cases { - match case.value { - crate::SwitchValue::Integer(value) => { - writeln!(self.out, "{}case {}{}: {{", l2, value, type_postfix)?; + let mut new_case = true; + for case in cases { + if case.fall_through && !case.body.is_empty() { + // TODO: we could do the same workaround as we did for the HLSL backend + return Err(Error::Unimplemented( + "fall-through switch case block".into(), + )); + } + + match case.value { + crate::SwitchValue::Integer(value) => { + if new_case { + write!(self.out, "{}case ", l2)?; } - crate::SwitchValue::Default => { - writeln!(self.out, "{}default: {{", l2)?; + write!(self.out, "{}{}", value, type_postfix)?; + } + crate::SwitchValue::Default => { + if new_case { + if case.fall_through { + write!(self.out, "{}case ", l2)?; + } else { + write!(self.out, "{}", l2)?; + } } + write!(self.out, "default")?; } + } - for sta in case.body.iter() { - self.write_stmt(module, sta, func_ctx, l2.next())?; - } + new_case = !case.fall_through; - if case.fall_through { - writeln!(self.out, "{}fallthrough;", l2.next())?; - } + if case.fall_through { + write!(self.out, ", ")?; + } else { + writeln!(self.out, ": {{")?; + } + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, l2.next())?; + } + + if !case.fall_through { writeln!(self.out, "{}}}", l2)?; } } diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 6117d0066a..4e2733f69e 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -1513,13 +1513,20 @@ impl Parser { lexer.span_from(initial) } - fn parse_switch_value<'a>(lexer: &mut Lexer<'a>, uint: bool) -> Result> { - let token_span = lexer.next(); - match token_span.0 { - Token::Number(Ok(Number::U32(num))) if uint => Ok(num as i32), - Token::Number(Ok(Number::I32(num))) if !uint => Ok(num), - Token::Number(Err(e)) => Err(Error::BadNumber(token_span.1, e)), - _ => Err(Error::Unexpected(token_span.1, ExpectedToken::Integer)), + fn parse_switch_value<'a>( + lexer: &mut Lexer<'a>, + uint: bool, + ) -> Result> { + match lexer.next() { + (Token::Word("default"), _) => Ok(crate::SwitchValue::Default), + (Token::Number(Ok(Number::U32(num))), _) if uint => { + Ok(crate::SwitchValue::Integer(num as i32)) + } + (Token::Number(Ok(Number::I32(num))), _) if !uint => { + Ok(crate::SwitchValue::Integer(num)) + } + (Token::Number(Err(e)), span) => Err(Error::BadNumber(span, e)), + (_, span) => Err(Error::Unexpected(span, ExpectedToken::Integer)), } } @@ -3576,34 +3583,6 @@ impl Parser { Ok(()) } - fn parse_switch_case_body<'a, 'out>( - &mut self, - lexer: &mut Lexer<'a>, - mut context: StatementContext<'a, '_, 'out>, - ) -> Result<(bool, crate::Block), Error<'a>> { - let mut body = crate::Block::new(); - // Push a new lexical scope for the switch case body - context.symbol_table.push_scope(); - - lexer.expect(Token::Paren('{'))?; - let fall_through = loop { - // default statements - if lexer.skip(Token::Word("fallthrough")) { - lexer.expect(Token::Separator(';'))?; - lexer.expect(Token::Paren('}'))?; - break true; - } - if lexer.skip(Token::Paren('}')) { - break false; - } - self.parse_statement(lexer, context.reborrow(), &mut body, false)?; - }; - // Pop the switch case body lexical scope - context.symbol_table.pop_scope(); - - Ok((fall_through, body)) - } - fn parse_statement<'a, 'out>( &mut self, lexer: &mut Lexer<'a>, @@ -3928,37 +3907,34 @@ impl Parser { break value; } cases.push(crate::SwitchCase { - value: crate::SwitchValue::Integer(value), + value, body: crate::Block::new(), fall_through: true, }); }; - let (fall_through, body) = - self.parse_switch_case_body(lexer, context.reborrow())?; - + let body = + self.parse_block(lexer, context.reborrow(), false)?; cases.push(crate::SwitchCase { - value: crate::SwitchValue::Integer(value), + value, body, - fall_through, + fall_through: false, }); } (Token::Word("default"), _) => { lexer.skip(Token::Separator(':')); - let (fall_through, body) = - self.parse_switch_case_body(lexer, context.reborrow())?; + + let body = + self.parse_block(lexer, context.reborrow(), false)?; cases.push(crate::SwitchCase { value: crate::SwitchValue::Default, body, - fall_through, + fall_through: false, }); } (Token::Paren('}'), _) => break, - other => { - return Err(Error::Unexpected( - other.1, - ExpectedToken::SwitchItem, - )) + (_, span) => { + return Err(Error::Unexpected(span, ExpectedToken::SwitchItem)) } } } diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index 33fc541acb..6aaa505d4d 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -249,8 +249,7 @@ fn parse_switch() { var pos: f32; switch (3) { case 0, 1: { pos = 0.0; } - case 2: { pos = 1.0; fallthrough; } - case 3: {} + case 2: { pos = 1.0; } default: { pos = 3.0; } } } @@ -267,8 +266,7 @@ fn parse_switch_optional_colon_in_case() { var pos: f32; switch (3) { case 0, 1 { pos = 0.0; } - case 2 { pos = 1.0; fallthrough; } - case 3 {} + case 2 { pos = 1.0; } default { pos = 3.0; } } } @@ -277,6 +275,23 @@ fn parse_switch_optional_colon_in_case() { .unwrap(); } +#[test] +fn parse_switch_default_in_case() { + parse_str( + " + fn main() { + var pos: f32; + switch (3) { + case 0, 1: { pos = 0.0; } + case 2: {} + case default, 3: { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + #[test] fn parse_parentheses_switch() { parse_str( diff --git a/tests/in/control-flow.wgsl b/tests/in/control-flow.wgsl index 787742d71d..5a0ef1cbbf 100644 --- a/tests/in/control-flow.wgsl +++ b/tests/in/control-flow.wgsl @@ -22,15 +22,13 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { case 2: { pos = 1; } - case 3: { + case 3, 4: { pos = 2; - fallthrough; } - case 4: { + case 5: { pos = 3; - fallthrough; } - default: { + case default, 6: { pos = 4; } } @@ -54,7 +52,6 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } case 3: { pos = 2; - fallthrough; } case 4: {} default: { diff --git a/tests/out/glsl/control-flow.main.Compute.glsl b/tests/out/glsl/control-flow.main.Compute.glsl index 5446e36098..a66c83a309 100644 --- a/tests/out/glsl/control-flow.main.Compute.glsl +++ b/tests/out/glsl/control-flow.main.Compute.glsl @@ -56,12 +56,16 @@ void main() { pos = 1; break; case 3: - pos = 2; /* fallthrough */ case 4: + pos = 2; + break; + case 5: pos = 3; - /* fallthrough */ + break; default: + /* fallthrough */ + case 6: pos = 4; break; } @@ -81,7 +85,7 @@ void main() { return; case 3: pos = 2; - /* fallthrough */ + return; case 4: return; default: diff --git a/tests/out/hlsl/control-flow.hlsl b/tests/out/hlsl/control-flow.hlsl index 4b879ef20c..354c552a15 100644 --- a/tests/out/hlsl/control-flow.hlsl +++ b/tests/out/hlsl/control-flow.hlsl @@ -60,27 +60,30 @@ void main(uint3 global_id : SV_DispatchThreadID) break; } case 3: { + { + } { pos = 2; } - { - pos = 3; - } - { - pos = 4; - } break; } case 4: { + pos = 2; + break; + } + case 5: { + pos = 3; + break; + } + default: { { - pos = 3; } { pos = 4; } break; } - default: { + case 6: { pos = 4; break; } @@ -104,12 +107,8 @@ void main(uint3 global_id : SV_DispatchThreadID) return; } case 3: { - { - pos = 2; - } - { - return; - } + pos = 2; + return; } case 4: { return; diff --git a/tests/out/msl/control-flow.msl b/tests/out/msl/control-flow.msl index 4a0b805663..d00b7ec18b 100644 --- a/tests/out/msl/control-flow.msl +++ b/tests/out/msl/control-flow.msl @@ -69,12 +69,18 @@ kernel void main_( break; } case 3: { - pos = 2; } case 4: { + pos = 2; + break; + } + case 5: { pos = 3; + break; } default: { + } + case 6: { pos = 4; break; } @@ -99,6 +105,7 @@ kernel void main_( } case 3: { pos = 2; + return; } case 4: { return; diff --git a/tests/out/spv/control-flow.spvasm b/tests/out/spv/control-flow.spvasm index dfd1b145a2..fa00d2dba9 100644 --- a/tests/out/spv/control-flow.spvasm +++ b/tests/out/spv/control-flow.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 69 +; Bound: 71 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -92,7 +92,7 @@ OpBranch %50 %50 = OpLabel %52 = OpLoad %4 %37 OpSelectionMerge %53 None -OpSwitch %52 %54 1 %55 2 %56 3 %57 4 %58 +OpSwitch %52 %54 1 %55 2 %56 3 %57 4 %58 5 %59 %55 = OpLabel OpStore %37 %5 OpBranch %53 @@ -100,39 +100,43 @@ OpBranch %53 OpStore %37 %3 OpBranch %53 %57 = OpLabel -OpStore %37 %6 OpBranch %58 %58 = OpLabel +OpStore %37 %6 +OpBranch %53 +%59 = OpLabel OpStore %37 %7 -OpBranch %54 +OpBranch %53 %54 = OpLabel +OpBranch %60 +%60 = OpLabel OpStore %37 %8 OpBranch %53 %53 = OpLabel -OpSelectionMerge %59 None -OpSwitch %9 %60 0 %61 +OpSelectionMerge %61 None +OpSwitch %9 %62 0 %63 +%63 = OpLabel +OpBranch %61 +%62 = OpLabel +OpBranch %61 %61 = OpLabel -OpBranch %59 -%60 = OpLabel -OpBranch %59 -%59 = OpLabel -%62 = OpLoad %4 %37 -OpSelectionMerge %63 None -OpSwitch %62 %64 1 %65 2 %66 3 %67 4 %68 -%65 = OpLabel +%64 = OpLoad %4 %37 +OpSelectionMerge %65 None +OpSwitch %64 %66 1 %67 2 %68 3 %69 4 %70 +%67 = OpLabel OpStore %37 %5 -OpBranch %63 -%66 = OpLabel +OpBranch %65 +%68 = OpLabel OpStore %37 %3 OpReturn -%67 = OpLabel +%69 = OpLabel OpStore %37 %6 -OpBranch %68 -%68 = OpLabel OpReturn -%64 = OpLabel +%70 = OpLabel +OpReturn +%66 = OpLabel OpStore %37 %7 OpReturn -%63 = OpLabel +%65 = OpLabel OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/control-flow.wgsl b/tests/out/wgsl/control-flow.wgsl index ce3911f2b1..2872406d76 100644 --- a/tests/out/wgsl/control-flow.wgsl +++ b/tests/out/wgsl/control-flow.wgsl @@ -50,15 +50,13 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { case 2: { pos = 1; } - case 3: { + case 3, 4: { pos = 2; - fallthrough; } - case 4: { + case 5: { pos = 3; - fallthrough; } - default: { + case default, 6: { pos = 4; } } @@ -80,7 +78,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } case 3: { pos = 2; - fallthrough; + return; } case 4: { return; diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index dd3e3dda6c..594d92354d 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -1370,28 +1370,6 @@ fn select() { } } -#[test] -fn last_case_falltrough() { - check_validation! { - " - fn test_falltrough() { - switch(0) { - default: {} - case 0: { - fallthrough; - } - } - } - ": - Err( - naga::valid::ValidationError::Function { - source: naga::valid::FunctionError::LastCaseFallTrough, - .. - }, - ) - } -} - #[test] fn missing_default_case() { check_validation! {