From d5fc05e8a4ea39aa91a4c9cc2da149eef5a03ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Sun, 19 Sep 2021 22:31:38 +0100 Subject: [PATCH] Allow unsigned integers in switch --- src/back/glsl/mod.rs | 9 +- src/back/hlsl/writer.rs | 13 +- src/back/msl/writer.rs | 9 +- src/back/wgsl/writer.rs | 11 +- src/front/glsl/parser/functions.rs | 11 +- src/front/wgsl/mod.rs | 33 +++- src/front/wgsl/number_literals.rs | 30 --- src/lib.rs | 2 +- src/valid/function.rs | 6 +- tests/in/control-flow.wgsl | 6 + tests/out/glsl/control-flow.main.Compute.glsl | 8 +- tests/out/hlsl/control-flow.hlsl | 9 +- tests/out/msl/control-flow.msl | 11 +- tests/out/spv/control-flow.spvasm | 180 +++++++++--------- tests/out/wgsl/control-flow.wgsl | 8 +- 15 files changed, 204 insertions(+), 142 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 9bba3c2414..4cb694ccd8 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1427,11 +1427,18 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, "switch(")?; self.write_expr(selector, ctx)?; writeln!(self.out, ") {{")?; + let type_postfix = match *ctx.info[selector].ty.inner_with(&self.module.types) { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => "u", + _ => "", + }; // Write all cases let l2 = level.next(); for case in cases { - writeln!(self.out, "{}case {}:", l2, case.value)?; + writeln!(self.out, "{}case {}{}:", l2, case.value, type_postfix)?; for sta in case.body.iter() { self.write_stmt(sta, ctx, l2.next())?; diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index b0b0d1a98d..d48daf1aae 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1370,13 +1370,24 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "switch(")?; self.write_expr(module, selector, func_ctx)?; writeln!(self.out, ") {{")?; + let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => "u", + _ => "", + }; // Write all cases let indent_level_1 = level.next(); let indent_level_2 = indent_level_1.next(); for case in cases { - writeln!(self.out, "{}case {}: {{", indent_level_1, case.value)?; + writeln!( + self.out, + "{}case {}{}: {{", + indent_level_1, case.value, type_postfix + )?; if case.fall_through { // Generate each fallthrough case statement in a new block. This is done to diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 0bedf32875..0bf159e545 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1439,10 +1439,17 @@ impl Writer { } => { write!(self.out, "{}switch(", level)?; self.put_expression(selector, &context.expression, true)?; + let type_postfix = match *context.expression.resolve_type(selector) { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => "u", + _ => "", + }; writeln!(self.out, ") {{")?; let lcase = level.next(); for case in cases.iter() { - writeln!(self.out, "{}case {}: {{", lcase, case.value)?; + writeln!(self.out, "{}case {}{}: {{", lcase, case.value, type_postfix)?; self.put_block(lcase.next(), &case.body, context)?; if !case.fall_through { writeln!(self.out, "{}break;", lcase.next())?; diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 5c0b61cb4c..f8956a691b 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -887,6 +887,13 @@ impl Writer { let all_fall_through = cases .iter() .all(|case| case.fall_through && case.body.is_empty()); + let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => "u", + _ => "", + }; let l2 = level.next(); if !cases.is_empty() { @@ -896,11 +903,11 @@ impl Writer { } if !all_fall_through && case.fall_through && case.body.is_empty() { write_case = false; - write!(self.out, "{}, ", case.value)?; + write!(self.out, "{}{}, ", case.value, type_postfix)?; continue; } else { write_case = true; - writeln!(self.out, "{}: {{", case.value)?; + writeln!(self.out, "{}{}: {{", case.value, type_postfix)?; } for sta in case.body.iter() { diff --git a/src/front/glsl/parser/functions.rs b/src/front/glsl/parser/functions.rs index cfcb2f9899..68001558bb 100644 --- a/src/front/glsl/parser/functions.rs +++ b/src/front/glsl/parser/functions.rs @@ -157,19 +157,12 @@ impl<'source> ParsingContext<'source> { self.expect(parser, TokenValue::LeftParen)?; - let (mut selector, selector_meta) = { + let selector = { let mut stmt = ctx.stmt_ctx(); let expr = self.parse_expression(parser, ctx, &mut stmt, body)?; - ctx.lower_expect(stmt, parser, expr, ExprPos::Rhs, body)? + ctx.lower_expect(stmt, parser, expr, ExprPos::Rhs, body)?.0 }; - if let Some(crate::ScalarKind::Uint) = parser - .resolve_type(ctx, selector, selector_meta)? - .scalar_kind() - { - ctx.conversion(&mut selector, selector_meta, crate::ScalarKind::Sint, 4)? - } - self.expect(parser, TokenValue::RightParen)?; ctx.emit_flush(body); diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 28f1e83b72..bc013f1474 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -21,7 +21,7 @@ use self::{ lexer::Lexer, number_literals::{ get_f32_literal, get_i32_literal, get_u32_literal, parse_generic_non_negative_int_literal, - parse_non_negative_sint_literal, parse_sint_literal, + parse_non_negative_sint_literal, }, }; use codespan_reporting::{ @@ -83,6 +83,7 @@ pub enum ExpectedToken<'a> { ty: Option, width: Option, }, + Integer, Constant, /// Expected: constant, parenthesized expression, identifier PrimaryExpression, @@ -218,6 +219,7 @@ impl<'a> Error<'a> { ) } }, + ExpectedToken::Integer => "unsigned/signed integer literal".to_string(), ExpectedToken::Constant => "constant".to_string(), ExpectedToken::PrimaryExpression => "expression".to_string(), ExpectedToken::AttributeSeparator => "attribute separator (',') or an end of the attribute list (']]')".to_string(), @@ -1242,6 +1244,28 @@ impl Parser { }) } + fn parse_switch_value<'a>(lexer: &mut Lexer<'a>, uint: bool) -> Result> { + let token_span = lexer.next(); + let word = match token_span.0 { + Token::Number { value, width, .. } => { + if let Some(width) = width { + if width != 4 { + // Only 32-bit literals supported by the spec and naga for now! + return Err(Error::BadScalarWidth(token_span.1, width)); + } + } + + value + } + _ => return Err(Error::Unexpected(token_span, ExpectedToken::Integer)), + }; + + match uint { + true => get_u32_literal(word, token_span.1).map(|v| v as i32), + false => get_i32_literal(word, token_span.1), + } + } + fn parse_atomic_pointer<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -3425,6 +3449,11 @@ impl Parser { lexer, context.as_expression(block, &mut emitter), )?; + let uint = Some(crate::ScalarKind::Uint) + == context + .as_expression(block, &mut emitter) + .resolve_type(selector)? + .scalar_kind(); lexer.expect(Token::Paren(')'))?; block.extend(emitter.finish(context.expressions)); lexer.expect(Token::Paren('{'))?; @@ -3438,7 +3467,7 @@ impl Parser { // parse a list of values let value = loop { // TODO: Switch statements also allow for floats, bools and unsigned integers. See https://www.w3.org/TR/WGSL/#switch-statement - let value = parse_sint_literal(lexer, 4)?; + let value = Self::parse_switch_value(lexer, uint)?; if lexer.skip(Token::Separator(',')) { if lexer.skip(Token::Separator(':')) { break value; diff --git a/src/front/wgsl/number_literals.rs b/src/front/wgsl/number_literals.rs index 2f61224142..9279989e94 100644 --- a/src/front/wgsl/number_literals.rs +++ b/src/front/wgsl/number_literals.rs @@ -69,36 +69,6 @@ pub fn get_f32_literal(word: &str, span: Span) -> Result> { parsed_val.map_err(|e| Error::BadFloat(span, e)) } -pub(super) fn parse_sint_literal<'a>( - lexer: &mut Lexer<'a>, - width: Bytes, -) -> Result> { - let token_span = lexer.next(); - - if width != 4 { - // Only 32-bit literals supported by the spec and naga for now! - return Err(Error::BadScalarWidth(token_span.1, width)); - } - - match token_span { - ( - Token::Number { - value, - ty: NumberType::Sint, - width: token_width, - }, - span, - ) if token_width.unwrap_or(4) == width => get_i32_literal(value, span), - other => Err(Error::Unexpected( - other, - ExpectedToken::Number { - ty: Some(NumberType::Sint), - width: Some(width), - }, - )), - } -} - pub(super) fn _parse_uint_literal<'a>( lexer: &mut Lexer<'a>, width: Bytes, diff --git a/src/lib.rs b/src/lib.rs index 3c83254854..697264c516 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1260,7 +1260,7 @@ pub use block::Block; pub struct SwitchCase { /// Value, upon which the case is considered true. pub value: i32, - /// Body of the cae. + /// Body of the case. pub body: Block, /// If true, the control flow continues to the next case in the list, /// or default. diff --git a/src/valid/function.rs b/src/valid/function.rs index ff39fcd573..ddfcb60bb8 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -90,7 +90,7 @@ pub enum FunctionError { InvalidIfType(Handle), #[error("The `switch` value {0:?} is not an integer scalar")] InvalidSwitchType(Handle), - #[error("Multiple `switch` cases for {0} are present")] + #[error("Multiple `switch` cases for {0:?} are present")] ConflictingSwitchCase(i32), #[error("The pointer {0:?} doesn't relate to a valid destination for a store")] InvalidStorePointer(Handle), @@ -375,6 +375,10 @@ impl super::Validator { ref default, } => { match *context.resolve_type(selector, &self.valid_expression_set)? { + Ti::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + } => {} Ti::Scalar { kind: crate::ScalarKind::Sint, width: _, diff --git a/tests/in/control-flow.wgsl b/tests/in/control-flow.wgsl index 679ad3d8a3..5ca59817c5 100644 --- a/tests/in/control-flow.wgsl +++ b/tests/in/control-flow.wgsl @@ -32,6 +32,12 @@ fn main([[builtin(global_invocation_id)]] global_id: vec3) { } } + // switch with unsigned integer selectors + switch(0u) { + case 0u: { + } + } + // non-empty switch in last-statement-in-function position switch (pos) { case 1: { diff --git a/tests/out/glsl/control-flow.main.Compute.glsl b/tests/out/glsl/control-flow.main.Compute.glsl index e8b566b9d9..4de563e898 100644 --- a/tests/out/glsl/control-flow.main.Compute.glsl +++ b/tests/out/glsl/control-flow.main.Compute.glsl @@ -56,8 +56,12 @@ void main() { default: pos = 3; } - int _e9 = pos; - switch(_e9) { + switch(0u) { + case 0u: + break; + } + int _e10 = pos; + switch(_e10) { case 1: pos = 0; break; diff --git a/tests/out/hlsl/control-flow.hlsl b/tests/out/hlsl/control-flow.hlsl index d425b0991e..5a771342f1 100644 --- a/tests/out/hlsl/control-flow.hlsl +++ b/tests/out/hlsl/control-flow.hlsl @@ -68,8 +68,13 @@ void main(uint3 global_id : SV_DispatchThreadID) pos = 3; } } - int _expr9 = pos; - switch(_expr9) { + switch(0u) { + case 0u: { + break; + } + } + int _expr10 = pos; + switch(_expr10) { case 1: { pos = 0; break; diff --git a/tests/out/msl/control-flow.msl b/tests/out/msl/control-flow.msl index 52f990134d..8662b8b30a 100644 --- a/tests/out/msl/control-flow.msl +++ b/tests/out/msl/control-flow.msl @@ -76,8 +76,15 @@ kernel void main1( pos = 3; } } - int _e9 = pos; - switch(_e9) { + switch(0u) { + case 0u: { + break; + } + default: { + } + } + int _e10 = pos; + switch(_e10) { case 1: { pos = 0; break; diff --git a/tests/out/spv/control-flow.spvasm b/tests/out/spv/control-flow.spvasm index e8f0e2cf15..1a34638d8c 100644 --- a/tests/out/spv/control-flow.spvasm +++ b/tests/out/spv/control-flow.spvasm @@ -1,13 +1,13 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 63 +; Bound: 67 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %41 "main" %38 -OpExecutionMode %41 LocalSize 1 1 1 -OpDecorate %38 BuiltIn GlobalInvocationId +OpEntryPoint GLCompute %42 "main" %39 +OpExecutionMode %42 LocalSize 1 1 1 +OpDecorate %39 BuiltIn GlobalInvocationId %2 = OpTypeVoid %4 = OpTypeInt 32 1 %3 = OpConstant %4 1 @@ -15,113 +15,121 @@ OpDecorate %38 BuiltIn GlobalInvocationId %6 = OpConstant %4 2 %7 = OpConstant %4 3 %9 = OpTypeInt 32 0 -%8 = OpTypeVector %9 3 -%13 = OpTypeFunction %2 %4 -%19 = OpTypeFunction %2 -%36 = OpTypePointer Function %4 -%39 = OpTypePointer Input %8 -%38 = OpVariable %39 Input -%43 = OpConstant %9 2 -%44 = OpConstant %9 1 -%45 = OpConstant %9 72 -%46 = OpConstant %9 264 -%12 = OpFunction %2 None %13 -%11 = OpFunctionParameter %4 -%10 = OpLabel -OpBranch %14 -%14 = OpLabel -OpSelectionMerge %15 None -OpSwitch %11 %16 -%16 = OpLabel +%8 = OpConstant %9 0 +%10 = OpTypeVector %9 3 +%14 = OpTypeFunction %2 %4 +%20 = OpTypeFunction %2 +%37 = OpTypePointer Function %4 +%40 = OpTypePointer Input %10 +%39 = OpVariable %40 Input +%44 = OpConstant %9 2 +%45 = OpConstant %9 1 +%46 = OpConstant %9 72 +%47 = OpConstant %9 264 +%13 = OpFunction %2 None %14 +%12 = OpFunctionParameter %4 +%11 = OpLabel OpBranch %15 %15 = OpLabel +OpSelectionMerge %16 None +OpSwitch %12 %17 +%17 = OpLabel +OpBranch %16 +%16 = OpLabel OpReturn OpFunctionEnd -%18 = OpFunction %2 None %19 -%17 = OpLabel -OpBranch %20 -%20 = OpLabel -OpSelectionMerge %21 None -OpSwitch %5 %22 0 %23 -%23 = OpLabel -OpBranch %21 -%22 = OpLabel +%19 = OpFunction %2 None %20 +%18 = OpLabel OpBranch %21 %21 = OpLabel +OpSelectionMerge %22 None +OpSwitch %5 %23 0 %24 +%24 = OpLabel +OpBranch %22 +%23 = OpLabel +OpBranch %22 +%22 = OpLabel OpReturn OpFunctionEnd -%26 = OpFunction %2 None %13 -%25 = OpFunctionParameter %4 -%24 = OpLabel -OpBranch %27 -%27 = OpLabel +%27 = OpFunction %2 None %14 +%26 = OpFunctionParameter %4 +%25 = OpLabel OpBranch %28 %28 = OpLabel -OpLoopMerge %29 %31 None -OpBranch %30 -%30 = OpLabel -OpSelectionMerge %32 None -OpSwitch %25 %33 1 %34 -%34 = OpLabel +OpBranch %29 +%29 = OpLabel +OpLoopMerge %30 %32 None OpBranch %31 +%31 = OpLabel +OpSelectionMerge %33 None +OpSwitch %26 %34 1 %35 +%35 = OpLabel +OpBranch %32 +%34 = OpLabel +OpBranch %33 %33 = OpLabel OpBranch %32 %32 = OpLabel -OpBranch %31 -%31 = OpLabel -OpBranch %28 -%29 = OpLabel +OpBranch %29 +%30 = OpLabel OpReturn OpFunctionEnd -%41 = OpFunction %2 None %19 -%37 = OpLabel -%35 = OpVariable %36 Function -%40 = OpLoad %8 %38 -OpBranch %42 -%42 = OpLabel -OpControlBarrier %43 %44 %45 -OpControlBarrier %43 %43 %46 -OpSelectionMerge %47 None -OpSwitch %3 %48 +%42 = OpFunction %2 None %20 +%38 = OpLabel +%36 = OpVariable %37 Function +%41 = OpLoad %10 %39 +OpBranch %43 +%43 = OpLabel +OpControlBarrier %44 %45 %46 +OpControlBarrier %44 %44 %47 +OpSelectionMerge %48 None +OpSwitch %3 %49 +%49 = OpLabel +OpStore %36 %3 +OpBranch %48 %48 = OpLabel -OpStore %35 %3 -OpBranch %47 -%47 = OpLabel -%49 = OpLoad %4 %35 -OpSelectionMerge %50 None -OpSwitch %49 %51 1 %52 2 %53 3 %54 4 %55 -%52 = OpLabel -OpStore %35 %5 -OpBranch %50 +%50 = OpLoad %4 %36 +OpSelectionMerge %51 None +OpSwitch %50 %52 1 %53 2 %54 3 %55 4 %56 %53 = OpLabel -OpStore %35 %3 -OpBranch %50 +OpStore %36 %5 +OpBranch %51 %54 = OpLabel -OpStore %35 %6 -OpBranch %55 +OpStore %36 %3 +OpBranch %51 %55 = OpLabel -OpBranch %50 +OpStore %36 %6 +OpBranch %56 +%56 = OpLabel +OpBranch %51 +%52 = OpLabel +OpStore %36 %7 +OpBranch %51 %51 = OpLabel -OpStore %35 %7 -OpBranch %50 -%50 = OpLabel -%56 = OpLoad %4 %35 OpSelectionMerge %57 None -OpSwitch %56 %58 1 %59 2 %60 3 %61 4 %62 +OpSwitch %8 %58 0 %59 %59 = OpLabel -OpStore %35 %5 OpBranch %57 -%60 = OpLabel -OpStore %35 %3 +%58 = OpLabel +OpBranch %57 +%57 = OpLabel +%60 = OpLoad %4 %36 +OpSelectionMerge %61 None +OpSwitch %60 %62 1 %63 2 %64 3 %65 4 %66 +%63 = OpLabel +OpStore %36 %5 +OpBranch %61 +%64 = OpLabel +OpStore %36 %3 +OpReturn +%65 = OpLabel +OpStore %36 %6 +OpBranch %66 +%66 = OpLabel +OpReturn +%62 = OpLabel +OpStore %36 %7 OpReturn %61 = OpLabel -OpStore %35 %6 -OpBranch %62 -%62 = OpLabel -OpReturn -%58 = OpLabel -OpStore %35 %7 -OpReturn -%57 = OpLabel OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/control-flow.wgsl b/tests/out/wgsl/control-flow.wgsl index 511c08aa5f..cdb3fc53c9 100644 --- a/tests/out/wgsl/control-flow.wgsl +++ b/tests/out/wgsl/control-flow.wgsl @@ -56,8 +56,12 @@ fn main([[builtin(global_invocation_id)]] global_id: vec3) { pos = 3; } } - let e9: i32 = pos; - switch(e9) { + switch(0u) { + case 0u: { + } + } + let e10: i32 = pos; + switch(e10) { case 1: { pos = 0; break;