From f7ca7f2aff384481dcd17132e277ad59a789f067 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 4 Dec 2020 19:26:01 -0500 Subject: [PATCH] Refactor Switch statement, implement on WGSL --- src/back/glsl/mod.rs | 8 +++--- src/back/msl/writer.rs | 8 +++--- src/front/glsl/parser.rs | 42 +++++++++++++++-------------- src/front/spv/flow.rs | 28 ++++++++------------ src/front/wgsl/lexer.rs | 2 +- src/front/wgsl/mod.rs | 57 ++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 16 ++++++++--- src/proc/interface.rs | 4 +-- src/proc/terminator.rs | 6 ++--- 9 files changed, 117 insertions(+), 54 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 96ec029d0f..98bfe02c9d 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1023,15 +1023,15 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ") {{")?; // Write all cases - for (label, (block, fallthrough)) in cases { - writeln!(self.out, "{}case {}:", "\t".repeat(indent + 1), label)?; + for case in cases { + writeln!(self.out, "{}case {}:", "\t".repeat(indent + 1), case.value)?; - for sta in block { + for sta in case.body.iter() { self.write_stmt(sta, ctx, indent + 2)?; } // Write `break;` if the block isn't fallthrough - if fallthrough.is_none() { + if case.fall_through { writeln!(self.out, "{}break;", "\t".repeat(indent + 2))?; } } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 0a9f29e04d..936f78ac50 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -505,10 +505,10 @@ impl Writer { self.put_expression(selector, context)?; writeln!(self.out, ") {{")?; let lcase = level.next(); - for (&value, &(ref block, ref fall_through)) in cases.iter() { - writeln!(self.out, "{}case {}: {{", lcase, value)?; - self.put_block(lcase.next(), block, context, return_value)?; - if fall_through.is_none() { + for case in cases.iter() { + writeln!(self.out, "{}case {}: {{", lcase, case.value)?; + self.put_block(lcase.next(), &case.body, context, return_value)?; + if case.fall_through { writeln!(self.out, "{}break;", lcase.next())?; } writeln!(self.out, "{}}}", lcase)?; diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index 29cbbd3fc3..797917dd96 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -8,11 +8,11 @@ pomelo! { use crate::{ proc::{ensure_block_returns, Typifier}, Arena, BinaryOperator, Binding, Block, Constant, - ConstantInner, EntryPoint, Expression, FallThrough, - FastHashMap, Function, GlobalVariable, Handle, Interpolation, + ConstantInner, EntryPoint, Expression, + Function, GlobalVariable, Handle, Interpolation, LocalVariable, MemberOrigin, SampleLevel, ScalarKind, - Statement, StorageAccess, - StorageClass, StructMember, Type, TypeInner, UnaryOperator, + Statement, StorageAccess, StorageClass, StructMember, + SwitchCase, Type, TypeInner, UnaryOperator, }; } %token #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum Token {}; @@ -62,8 +62,8 @@ pomelo! { %type jump_statement Statement; %type iteration_statement Statement; %type selection_statement Statement; - %type switch_statement_list Vec<(Option, Block, Option)>; - %type switch_statement (Option, Block, Option); + %type switch_statement_list Vec<(Option, Block, bool)>; + %type switch_statement (Option, Block, bool); %type for_init_statement Statement; %type for_rest_statement (Option, Option); %type condition_opt Option; @@ -847,12 +847,16 @@ pomelo! { selection_statement ::= Switch LeftParen expression(e) RightParen LeftBrace switch_statement_list(ls) RightBrace { let mut default = Vec::new(); - let mut cases = FastHashMap::default(); - for (v, s, ft) in ls { - if let Some(v) = v { - cases.insert(v, (s, ft)); + let mut cases = Vec::new(); + for (v, body, fall_through) in ls { + if let Some(value) = v { + cases.push(SwitchCase { + value, + body, + fall_through, + }); } else { - default.extend_from_slice(&s); + default.extend_from_slice(&body); } } Statement::Switch { @@ -870,18 +874,18 @@ pomelo! { ssl } switch_statement ::= Case IntConstant(v) Colon statement_list(sl) { - let fallthrough = match sl.last() { - Some(Statement::Break) => None, - _ => Some(FallThrough), + let fall_through = match sl.last() { + Some(&Statement::Break) => false, + _ => true, }; - (Some(v.1 as i32), sl, fallthrough) + (Some(v.1 as i32), sl, fall_through) } switch_statement ::= Default Colon statement_list(sl) { - let fallthrough = match sl.last() { - Some(Statement::Break) => Some(FallThrough), - _ => None, + let fall_through = match sl.last() { + Some(&Statement::Break) => true, + _ => false, }; - (None, sl, fallthrough) + (None, sl, fall_through) } iteration_statement ::= While LeftParen expression(e) RightParen compound_statement_no_new_scope(sl) { diff --git a/src/front/spv/flow.rs b/src/front/spv/flow.rs index 32eac66941..c11a1eda0f 100644 --- a/src/front/spv/flow.rs +++ b/src/front/spv/flow.rs @@ -271,35 +271,29 @@ impl FlowGraph { } => { let merge_node_index = self.block_to_node[&node.merge.unwrap().merge_block_id]; let mut result = node.block.clone(); - - let mut cases = FastHashMap::default(); + let mut cases = Vec::with_capacity(targets.len()); for i in 0..targets.len() { let left_target_node_index = self.block_to_node[&targets[i].1]; - let fallthrough: Option = if i < targets.len() - 1 { + let fall_through = if i < targets.len() - 1 { let right_target_node_index = self.block_to_node[&targets[i + 1].1]; - if has_path_connecting( + has_path_connecting( &self.flow, left_target_node_index, right_target_node_index, None, - ) { - Some(crate::FallThrough {}) - } else { - None - } + ) } else { - None + false }; - cases.insert( - targets[i].0, - ( - self.naga_traverse(left_target_node_index, Some(merge_node_index))?, - fallthrough, - ), - ); + cases.push(crate::SwitchCase { + value: targets[i].0, + body: self + .naga_traverse(left_target_node_index, Some(merge_node_index))?, + fall_through, + }); } result.push(crate::Statement::Switch { diff --git a/src/front/wgsl/lexer.rs b/src/front/wgsl/lexer.rs index b991ecf619..0dd2627ef3 100644 --- a/src/front/wgsl/lexer.rs +++ b/src/front/wgsl/lexer.rs @@ -198,7 +198,7 @@ impl<'a> Lexer<'a> { } } - fn _next_sint_literal(&mut self) -> Result> { + pub(super) fn next_sint_literal(&mut self) -> Result> { match self.next() { Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)), other => Err(Error::Unexpected(other)), diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index e2c95ee779..620e221e71 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -1460,6 +1460,63 @@ impl Parser { reject, } } + "switch" => { + lexer.expect(Token::Paren('('))?; + let selector = + self.parse_general_expression(lexer, context.as_expression())?; + lexer.expect(Token::Paren(')'))?; + lexer.expect(Token::Paren('{'))?; + let mut cases = Vec::new(); + let mut default = Vec::new(); + loop { + match lexer.next() { + Token::Word("case") => loop { + let value = lexer.next_sint_literal()?; + lexer.expect(Token::Separator(':'))?; + let mut body = Vec::new(); + if lexer.skip(Token::Separator(',')) { + cases.push(crate::SwitchCase { + value, + body, + fall_through: true, + }); + } else { + lexer.expect(Token::Paren('{'))?; + let fall_through = loop { + if lexer.skip(Token::Word("fallthrough")) { + lexer.expect(Token::Separator(';'))?; + lexer.expect(Token::Paren('}'))?; + break true; + } + if lexer.skip(Token::Paren('}')) { + break false; + } + let s = + self.parse_statement(lexer, context.reborrow())?; + body.push(s); + }; + cases.push(crate::SwitchCase { + value, + body, + fall_through, + }); + break; + } + }, + Token::Word("default") => { + lexer.expect(Token::Separator(':'))?; + default = self.parse_block(lexer, context.reborrow())?; + } + Token::Paren('}') => break, + other => return Err(Error::Unexpected(other)), + } + } + crate::Statement::Switch { + selector, + cases, + default, + } + } "loop" => { let mut body = Vec::new(); let mut continuing = Vec::new(); diff --git a/src/lib.rs b/src/lib.rs index 936c3b69f5..7b141b5388 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -666,12 +666,20 @@ pub enum Expression { /// A code block is just a vector of statements. pub type Block = Vec; -/// Marker type, used for falling through in a switch statement. +/// A case for a switch statement. // Clone is used only for error reporting and is not intended for end users -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] -pub struct FallThrough; +pub struct SwitchCase { + /// Value, upon which the case is considered true. + pub value: i32, + /// Body of the cae. + pub body: Block, + /// If true, the control flow continues to the next case in the list, + /// or default. + pub fall_through: bool, +} /// Instructions which make up an executable block. // Clone is used only for error reporting and is not intended for end users @@ -690,7 +698,7 @@ pub enum Statement { /// Conditionally executes one of multiple blocks, based on the value of the selector. Switch { selector: Handle, //int - cases: FastHashMap)>, + cases: Vec, default: Block, }, /// Executes a block repeatedly. diff --git a/src/proc/interface.rs b/src/proc/interface.rs index b512452fe2..e3835f263a 100644 --- a/src/proc/interface.rs +++ b/src/proc/interface.rs @@ -148,8 +148,8 @@ where ref default, } => { self.traverse_expr(selector); - for &(ref case, _) in cases.values() { - self.traverse(case); + for case in cases.iter() { + self.traverse(&case.body); } self.traverse(default); } diff --git a/src/proc/terminator.rs b/src/proc/terminator.rs index 8b4d6ff330..1ad1429813 100644 --- a/src/proc/terminator.rs +++ b/src/proc/terminator.rs @@ -22,9 +22,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) { ref mut cases, ref mut default, }) => { - for case in cases.values_mut() { - if let (ref mut b, None) = *case { - ensure_block_returns(b); + for case in cases.iter_mut() { + if !case.fall_through { + ensure_block_returns(&mut case.body); } } ensure_block_returns(default);