Allow unsigned integers in switch

This commit is contained in:
João Capucho
2021-09-19 22:31:38 +01:00
committed by Dzmitry Malyshau
parent 6a57559070
commit d5fc05e8a4
15 changed files with 204 additions and 142 deletions

View File

@@ -1427,11 +1427,18 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, "switch(")?;
self.write_expr(selector, ctx)?;
writeln!(self.out, ") {{")?;
let type_postfix = match *ctx.info[selector].ty.inner_with(&self.module.types) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => "u",
_ => "",
};
// Write all cases
let l2 = level.next();
for case in cases {
writeln!(self.out, "{}case {}:", l2, case.value)?;
writeln!(self.out, "{}case {}{}:", l2, case.value, type_postfix)?;
for sta in case.body.iter() {
self.write_stmt(sta, ctx, l2.next())?;

View File

@@ -1370,13 +1370,24 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, "switch(")?;
self.write_expr(module, selector, func_ctx)?;
writeln!(self.out, ") {{")?;
let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => "u",
_ => "",
};
// Write all cases
let indent_level_1 = level.next();
let indent_level_2 = indent_level_1.next();
for case in cases {
writeln!(self.out, "{}case {}: {{", indent_level_1, case.value)?;
writeln!(
self.out,
"{}case {}{}: {{",
indent_level_1, case.value, type_postfix
)?;
if case.fall_through {
// Generate each fallthrough case statement in a new block. This is done to

View File

@@ -1439,10 +1439,17 @@ impl<W: Write> Writer<W> {
} => {
write!(self.out, "{}switch(", level)?;
self.put_expression(selector, &context.expression, true)?;
let type_postfix = match *context.expression.resolve_type(selector) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => "u",
_ => "",
};
writeln!(self.out, ") {{")?;
let lcase = level.next();
for case in cases.iter() {
writeln!(self.out, "{}case {}: {{", lcase, case.value)?;
writeln!(self.out, "{}case {}{}: {{", lcase, case.value, type_postfix)?;
self.put_block(lcase.next(), &case.body, context)?;
if !case.fall_through {
writeln!(self.out, "{}break;", lcase.next())?;

View File

@@ -887,6 +887,13 @@ impl<W: Write> Writer<W> {
let all_fall_through = cases
.iter()
.all(|case| case.fall_through && case.body.is_empty());
let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => "u",
_ => "",
};
let l2 = level.next();
if !cases.is_empty() {
@@ -896,11 +903,11 @@ impl<W: Write> Writer<W> {
}
if !all_fall_through && case.fall_through && case.body.is_empty() {
write_case = false;
write!(self.out, "{}, ", case.value)?;
write!(self.out, "{}{}, ", case.value, type_postfix)?;
continue;
} else {
write_case = true;
writeln!(self.out, "{}: {{", case.value)?;
writeln!(self.out, "{}{}: {{", case.value, type_postfix)?;
}
for sta in case.body.iter() {

View File

@@ -157,19 +157,12 @@ impl<'source> ParsingContext<'source> {
self.expect(parser, TokenValue::LeftParen)?;
let (mut selector, selector_meta) = {
let selector = {
let mut stmt = ctx.stmt_ctx();
let expr = self.parse_expression(parser, ctx, &mut stmt, body)?;
ctx.lower_expect(stmt, parser, expr, ExprPos::Rhs, body)?
ctx.lower_expect(stmt, parser, expr, ExprPos::Rhs, body)?.0
};
if let Some(crate::ScalarKind::Uint) = parser
.resolve_type(ctx, selector, selector_meta)?
.scalar_kind()
{
ctx.conversion(&mut selector, selector_meta, crate::ScalarKind::Sint, 4)?
}
self.expect(parser, TokenValue::RightParen)?;
ctx.emit_flush(body);

View File

@@ -21,7 +21,7 @@ use self::{
lexer::Lexer,
number_literals::{
get_f32_literal, get_i32_literal, get_u32_literal, parse_generic_non_negative_int_literal,
parse_non_negative_sint_literal, parse_sint_literal,
parse_non_negative_sint_literal,
},
};
use codespan_reporting::{
@@ -83,6 +83,7 @@ pub enum ExpectedToken<'a> {
ty: Option<NumberType>,
width: Option<Bytes>,
},
Integer,
Constant,
/// Expected: constant, parenthesized expression, identifier
PrimaryExpression,
@@ -218,6 +219,7 @@ impl<'a> Error<'a> {
)
}
},
ExpectedToken::Integer => "unsigned/signed integer literal".to_string(),
ExpectedToken::Constant => "constant".to_string(),
ExpectedToken::PrimaryExpression => "expression".to_string(),
ExpectedToken::AttributeSeparator => "attribute separator (',') or an end of the attribute list (']]')".to_string(),
@@ -1242,6 +1244,28 @@ impl Parser {
})
}
fn parse_switch_value<'a>(lexer: &mut Lexer<'a>, uint: bool) -> Result<i32, Error<'a>> {
let token_span = lexer.next();
let word = match token_span.0 {
Token::Number { value, width, .. } => {
if let Some(width) = width {
if width != 4 {
// Only 32-bit literals supported by the spec and naga for now!
return Err(Error::BadScalarWidth(token_span.1, width));
}
}
value
}
_ => return Err(Error::Unexpected(token_span, ExpectedToken::Integer)),
};
match uint {
true => get_u32_literal(word, token_span.1).map(|v| v as i32),
false => get_i32_literal(word, token_span.1),
}
}
fn parse_atomic_pointer<'a>(
&mut self,
lexer: &mut Lexer<'a>,
@@ -3425,6 +3449,11 @@ impl Parser {
lexer,
context.as_expression(block, &mut emitter),
)?;
let uint = Some(crate::ScalarKind::Uint)
== context
.as_expression(block, &mut emitter)
.resolve_type(selector)?
.scalar_kind();
lexer.expect(Token::Paren(')'))?;
block.extend(emitter.finish(context.expressions));
lexer.expect(Token::Paren('{'))?;
@@ -3438,7 +3467,7 @@ impl Parser {
// parse a list of values
let value = loop {
// TODO: Switch statements also allow for floats, bools and unsigned integers. See https://www.w3.org/TR/WGSL/#switch-statement
let value = parse_sint_literal(lexer, 4)?;
let value = Self::parse_switch_value(lexer, uint)?;
if lexer.skip(Token::Separator(',')) {
if lexer.skip(Token::Separator(':')) {
break value;

View File

@@ -69,36 +69,6 @@ pub fn get_f32_literal(word: &str, span: Span) -> Result<f32, Error<'_>> {
parsed_val.map_err(|e| Error::BadFloat(span, e))
}
pub(super) fn parse_sint_literal<'a>(
lexer: &mut Lexer<'a>,
width: Bytes,
) -> Result<i32, Error<'a>> {
let token_span = lexer.next();
if width != 4 {
// Only 32-bit literals supported by the spec and naga for now!
return Err(Error::BadScalarWidth(token_span.1, width));
}
match token_span {
(
Token::Number {
value,
ty: NumberType::Sint,
width: token_width,
},
span,
) if token_width.unwrap_or(4) == width => get_i32_literal(value, span),
other => Err(Error::Unexpected(
other,
ExpectedToken::Number {
ty: Some(NumberType::Sint),
width: Some(width),
},
)),
}
}
pub(super) fn _parse_uint_literal<'a>(
lexer: &mut Lexer<'a>,
width: Bytes,

View File

@@ -1260,7 +1260,7 @@ pub use block::Block;
pub struct SwitchCase {
/// Value, upon which the case is considered true.
pub value: i32,
/// Body of the cae.
/// Body of the case.
pub body: Block,
/// If true, the control flow continues to the next case in the list,
/// or default.

View File

@@ -90,7 +90,7 @@ pub enum FunctionError {
InvalidIfType(Handle<crate::Expression>),
#[error("The `switch` value {0:?} is not an integer scalar")]
InvalidSwitchType(Handle<crate::Expression>),
#[error("Multiple `switch` cases for {0} are present")]
#[error("Multiple `switch` cases for {0:?} are present")]
ConflictingSwitchCase(i32),
#[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
InvalidStorePointer(Handle<crate::Expression>),
@@ -375,6 +375,10 @@ impl super::Validator {
ref default,
} => {
match *context.resolve_type(selector, &self.valid_expression_set)? {
Ti::Scalar {
kind: crate::ScalarKind::Uint,
width: _,
} => {}
Ti::Scalar {
kind: crate::ScalarKind::Sint,
width: _,

View File

@@ -32,6 +32,12 @@ fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
}
}
// switch with unsigned integer selectors
switch(0u) {
case 0u: {
}
}
// non-empty switch in last-statement-in-function position
switch (pos) {
case 1: {

View File

@@ -56,8 +56,12 @@ void main() {
default:
pos = 3;
}
int _e9 = pos;
switch(_e9) {
switch(0u) {
case 0u:
break;
}
int _e10 = pos;
switch(_e10) {
case 1:
pos = 0;
break;

View File

@@ -68,8 +68,13 @@ void main(uint3 global_id : SV_DispatchThreadID)
pos = 3;
}
}
int _expr9 = pos;
switch(_expr9) {
switch(0u) {
case 0u: {
break;
}
}
int _expr10 = pos;
switch(_expr10) {
case 1: {
pos = 0;
break;

View File

@@ -76,8 +76,15 @@ kernel void main1(
pos = 3;
}
}
int _e9 = pos;
switch(_e9) {
switch(0u) {
case 0u: {
break;
}
default: {
}
}
int _e10 = pos;
switch(_e10) {
case 1: {
pos = 0;
break;

View File

@@ -1,13 +1,13 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 63
; Bound: 67
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %41 "main" %38
OpExecutionMode %41 LocalSize 1 1 1
OpDecorate %38 BuiltIn GlobalInvocationId
OpEntryPoint GLCompute %42 "main" %39
OpExecutionMode %42 LocalSize 1 1 1
OpDecorate %39 BuiltIn GlobalInvocationId
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpConstant %4 1
@@ -15,113 +15,121 @@ OpDecorate %38 BuiltIn GlobalInvocationId
%6 = OpConstant %4 2
%7 = OpConstant %4 3
%9 = OpTypeInt 32 0
%8 = OpTypeVector %9 3
%13 = OpTypeFunction %2 %4
%19 = OpTypeFunction %2
%36 = OpTypePointer Function %4
%39 = OpTypePointer Input %8
%38 = OpVariable %39 Input
%43 = OpConstant %9 2
%44 = OpConstant %9 1
%45 = OpConstant %9 72
%46 = OpConstant %9 264
%12 = OpFunction %2 None %13
%11 = OpFunctionParameter %4
%10 = OpLabel
OpBranch %14
%14 = OpLabel
OpSelectionMerge %15 None
OpSwitch %11 %16
%16 = OpLabel
%8 = OpConstant %9 0
%10 = OpTypeVector %9 3
%14 = OpTypeFunction %2 %4
%20 = OpTypeFunction %2
%37 = OpTypePointer Function %4
%40 = OpTypePointer Input %10
%39 = OpVariable %40 Input
%44 = OpConstant %9 2
%45 = OpConstant %9 1
%46 = OpConstant %9 72
%47 = OpConstant %9 264
%13 = OpFunction %2 None %14
%12 = OpFunctionParameter %4
%11 = OpLabel
OpBranch %15
%15 = OpLabel
OpSelectionMerge %16 None
OpSwitch %12 %17
%17 = OpLabel
OpBranch %16
%16 = OpLabel
OpReturn
OpFunctionEnd
%18 = OpFunction %2 None %19
%17 = OpLabel
OpBranch %20
%20 = OpLabel
OpSelectionMerge %21 None
OpSwitch %5 %22 0 %23
%23 = OpLabel
OpBranch %21
%22 = OpLabel
%19 = OpFunction %2 None %20
%18 = OpLabel
OpBranch %21
%21 = OpLabel
OpSelectionMerge %22 None
OpSwitch %5 %23 0 %24
%24 = OpLabel
OpBranch %22
%23 = OpLabel
OpBranch %22
%22 = OpLabel
OpReturn
OpFunctionEnd
%26 = OpFunction %2 None %13
%25 = OpFunctionParameter %4
%24 = OpLabel
OpBranch %27
%27 = OpLabel
%27 = OpFunction %2 None %14
%26 = OpFunctionParameter %4
%25 = OpLabel
OpBranch %28
%28 = OpLabel
OpLoopMerge %29 %31 None
OpBranch %30
%30 = OpLabel
OpSelectionMerge %32 None
OpSwitch %25 %33 1 %34
%34 = OpLabel
OpBranch %29
%29 = OpLabel
OpLoopMerge %30 %32 None
OpBranch %31
%31 = OpLabel
OpSelectionMerge %33 None
OpSwitch %26 %34 1 %35
%35 = OpLabel
OpBranch %32
%34 = OpLabel
OpBranch %33
%33 = OpLabel
OpBranch %32
%32 = OpLabel
OpBranch %31
%31 = OpLabel
OpBranch %28
%29 = OpLabel
OpBranch %29
%30 = OpLabel
OpReturn
OpFunctionEnd
%41 = OpFunction %2 None %19
%37 = OpLabel
%35 = OpVariable %36 Function
%40 = OpLoad %8 %38
OpBranch %42
%42 = OpLabel
OpControlBarrier %43 %44 %45
OpControlBarrier %43 %43 %46
OpSelectionMerge %47 None
OpSwitch %3 %48
%42 = OpFunction %2 None %20
%38 = OpLabel
%36 = OpVariable %37 Function
%41 = OpLoad %10 %39
OpBranch %43
%43 = OpLabel
OpControlBarrier %44 %45 %46
OpControlBarrier %44 %44 %47
OpSelectionMerge %48 None
OpSwitch %3 %49
%49 = OpLabel
OpStore %36 %3
OpBranch %48
%48 = OpLabel
OpStore %35 %3
OpBranch %47
%47 = OpLabel
%49 = OpLoad %4 %35
OpSelectionMerge %50 None
OpSwitch %49 %51 1 %52 2 %53 3 %54 4 %55
%52 = OpLabel
OpStore %35 %5
OpBranch %50
%50 = OpLoad %4 %36
OpSelectionMerge %51 None
OpSwitch %50 %52 1 %53 2 %54 3 %55 4 %56
%53 = OpLabel
OpStore %35 %3
OpBranch %50
OpStore %36 %5
OpBranch %51
%54 = OpLabel
OpStore %35 %6
OpBranch %55
OpStore %36 %3
OpBranch %51
%55 = OpLabel
OpBranch %50
OpStore %36 %6
OpBranch %56
%56 = OpLabel
OpBranch %51
%52 = OpLabel
OpStore %36 %7
OpBranch %51
%51 = OpLabel
OpStore %35 %7
OpBranch %50
%50 = OpLabel
%56 = OpLoad %4 %35
OpSelectionMerge %57 None
OpSwitch %56 %58 1 %59 2 %60 3 %61 4 %62
OpSwitch %8 %58 0 %59
%59 = OpLabel
OpStore %35 %5
OpBranch %57
%60 = OpLabel
OpStore %35 %3
%58 = OpLabel
OpBranch %57
%57 = OpLabel
%60 = OpLoad %4 %36
OpSelectionMerge %61 None
OpSwitch %60 %62 1 %63 2 %64 3 %65 4 %66
%63 = OpLabel
OpStore %36 %5
OpBranch %61
%64 = OpLabel
OpStore %36 %3
OpReturn
%65 = OpLabel
OpStore %36 %6
OpBranch %66
%66 = OpLabel
OpReturn
%62 = OpLabel
OpStore %36 %7
OpReturn
%61 = OpLabel
OpStore %35 %6
OpBranch %62
%62 = OpLabel
OpReturn
%58 = OpLabel
OpStore %35 %7
OpReturn
%57 = OpLabel
OpReturn
OpFunctionEnd

View File

@@ -56,8 +56,12 @@ fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
pos = 3;
}
}
let e9: i32 = pos;
switch(e9) {
switch(0u) {
case 0u: {
}
}
let e10: i32 = pos;
switch(e10) {
case 1: {
pos = 0;
break;