Refactor Switch statement, implement on WGSL

This commit is contained in:
Dzmitry Malyshau
2020-12-04 19:26:01 -05:00
committed by Dzmitry Malyshau
parent 5fe9429a63
commit f7ca7f2aff
9 changed files with 117 additions and 54 deletions

View File

@@ -1023,15 +1023,15 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, ") {{")?;
// Write all cases
for (label, (block, fallthrough)) in cases {
writeln!(self.out, "{}case {}:", "\t".repeat(indent + 1), label)?;
for case in cases {
writeln!(self.out, "{}case {}:", "\t".repeat(indent + 1), case.value)?;
for sta in block {
for sta in case.body.iter() {
self.write_stmt(sta, ctx, indent + 2)?;
}
// Write `break;` if the block isn't fallthrough
if fallthrough.is_none() {
if case.fall_through {
writeln!(self.out, "{}break;", "\t".repeat(indent + 2))?;
}
}

View File

@@ -505,10 +505,10 @@ impl<W: Write> Writer<W> {
self.put_expression(selector, context)?;
writeln!(self.out, ") {{")?;
let lcase = level.next();
for (&value, &(ref block, ref fall_through)) in cases.iter() {
writeln!(self.out, "{}case {}: {{", lcase, value)?;
self.put_block(lcase.next(), block, context, return_value)?;
if fall_through.is_none() {
for case in cases.iter() {
writeln!(self.out, "{}case {}: {{", lcase, case.value)?;
self.put_block(lcase.next(), &case.body, context, return_value)?;
if case.fall_through {
writeln!(self.out, "{}break;", lcase.next())?;
}
writeln!(self.out, "{}}}", lcase)?;

View File

@@ -8,11 +8,11 @@ pomelo! {
use crate::{
proc::{ensure_block_returns, Typifier},
Arena, BinaryOperator, Binding, Block, Constant,
ConstantInner, EntryPoint, Expression, FallThrough,
FastHashMap, Function, GlobalVariable, Handle, Interpolation,
ConstantInner, EntryPoint, Expression,
Function, GlobalVariable, Handle, Interpolation,
LocalVariable, MemberOrigin, SampleLevel, ScalarKind,
Statement, StorageAccess,
StorageClass, StructMember, Type, TypeInner, UnaryOperator,
Statement, StorageAccess, StorageClass, StructMember,
SwitchCase, Type, TypeInner, UnaryOperator,
};
}
%token #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum Token {};
@@ -62,8 +62,8 @@ pomelo! {
%type jump_statement Statement;
%type iteration_statement Statement;
%type selection_statement Statement;
%type switch_statement_list Vec<(Option<i32>, Block, Option<FallThrough>)>;
%type switch_statement (Option<i32>, Block, Option<FallThrough>);
%type switch_statement_list Vec<(Option<i32>, Block, bool)>;
%type switch_statement (Option<i32>, Block, bool);
%type for_init_statement Statement;
%type for_rest_statement (Option<ExpressionRule>, Option<ExpressionRule>);
%type condition_opt Option<ExpressionRule>;
@@ -847,12 +847,16 @@ pomelo! {
selection_statement ::= Switch LeftParen expression(e) RightParen LeftBrace switch_statement_list(ls) RightBrace {
let mut default = Vec::new();
let mut cases = FastHashMap::default();
for (v, s, ft) in ls {
if let Some(v) = v {
cases.insert(v, (s, ft));
let mut cases = Vec::new();
for (v, body, fall_through) in ls {
if let Some(value) = v {
cases.push(SwitchCase {
value,
body,
fall_through,
});
} else {
default.extend_from_slice(&s);
default.extend_from_slice(&body);
}
}
Statement::Switch {
@@ -870,18 +874,18 @@ pomelo! {
ssl
}
switch_statement ::= Case IntConstant(v) Colon statement_list(sl) {
let fallthrough = match sl.last() {
Some(Statement::Break) => None,
_ => Some(FallThrough),
let fall_through = match sl.last() {
Some(&Statement::Break) => false,
_ => true,
};
(Some(v.1 as i32), sl, fallthrough)
(Some(v.1 as i32), sl, fall_through)
}
switch_statement ::= Default Colon statement_list(sl) {
let fallthrough = match sl.last() {
Some(Statement::Break) => Some(FallThrough),
_ => None,
let fall_through = match sl.last() {
Some(&Statement::Break) => true,
_ => false,
};
(None, sl, fallthrough)
(None, sl, fall_through)
}
iteration_statement ::= While LeftParen expression(e) RightParen compound_statement_no_new_scope(sl) {

View File

@@ -271,35 +271,29 @@ impl FlowGraph {
} => {
let merge_node_index = self.block_to_node[&node.merge.unwrap().merge_block_id];
let mut result = node.block.clone();
let mut cases = FastHashMap::default();
let mut cases = Vec::with_capacity(targets.len());
for i in 0..targets.len() {
let left_target_node_index = self.block_to_node[&targets[i].1];
let fallthrough: Option<crate::FallThrough> = if i < targets.len() - 1 {
let fall_through = if i < targets.len() - 1 {
let right_target_node_index = self.block_to_node[&targets[i + 1].1];
if has_path_connecting(
has_path_connecting(
&self.flow,
left_target_node_index,
right_target_node_index,
None,
) {
Some(crate::FallThrough {})
} else {
None
}
)
} else {
None
false
};
cases.insert(
targets[i].0,
(
self.naga_traverse(left_target_node_index, Some(merge_node_index))?,
fallthrough,
),
);
cases.push(crate::SwitchCase {
value: targets[i].0,
body: self
.naga_traverse(left_target_node_index, Some(merge_node_index))?,
fall_through,
});
}
result.push(crate::Statement::Switch {

View File

@@ -198,7 +198,7 @@ impl<'a> Lexer<'a> {
}
}
fn _next_sint_literal(&mut self) -> Result<i32, Error<'a>> {
pub(super) fn next_sint_literal(&mut self) -> Result<i32, Error<'a>> {
match self.next() {
Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)),
other => Err(Error::Unexpected(other)),

View File

@@ -1460,6 +1460,63 @@ impl Parser {
reject,
}
}
"switch" => {
lexer.expect(Token::Paren('('))?;
let selector =
self.parse_general_expression(lexer, context.as_expression())?;
lexer.expect(Token::Paren(')'))?;
lexer.expect(Token::Paren('{'))?;
let mut cases = Vec::new();
let mut default = Vec::new();
loop {
match lexer.next() {
Token::Word("case") => loop {
let value = lexer.next_sint_literal()?;
lexer.expect(Token::Separator(':'))?;
let mut body = Vec::new();
if lexer.skip(Token::Separator(',')) {
cases.push(crate::SwitchCase {
value,
body,
fall_through: true,
});
} else {
lexer.expect(Token::Paren('{'))?;
let fall_through = loop {
if lexer.skip(Token::Word("fallthrough")) {
lexer.expect(Token::Separator(';'))?;
lexer.expect(Token::Paren('}'))?;
break true;
}
if lexer.skip(Token::Paren('}')) {
break false;
}
let s =
self.parse_statement(lexer, context.reborrow())?;
body.push(s);
};
cases.push(crate::SwitchCase {
value,
body,
fall_through,
});
break;
}
},
Token::Word("default") => {
lexer.expect(Token::Separator(':'))?;
default = self.parse_block(lexer, context.reborrow())?;
}
Token::Paren('}') => break,
other => return Err(Error::Unexpected(other)),
}
}
crate::Statement::Switch {
selector,
cases,
default,
}
}
"loop" => {
let mut body = Vec::new();
let mut continuing = Vec::new();

View File

@@ -666,12 +666,20 @@ pub enum Expression {
/// A code block is just a vector of statements.
pub type Block = Vec<Statement>;
/// Marker type, used for falling through in a switch statement.
/// A case for a switch statement.
// Clone is used only for error reporting and is not intended for end users
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub struct FallThrough;
pub struct SwitchCase {
/// Value, upon which the case is considered true.
pub value: i32,
/// Body of the cae.
pub body: Block,
/// If true, the control flow continues to the next case in the list,
/// or default.
pub fall_through: bool,
}
/// Instructions which make up an executable block.
// Clone is used only for error reporting and is not intended for end users
@@ -690,7 +698,7 @@ pub enum Statement {
/// Conditionally executes one of multiple blocks, based on the value of the selector.
Switch {
selector: Handle<Expression>, //int
cases: FastHashMap<i32, (Block, Option<FallThrough>)>,
cases: Vec<SwitchCase>,
default: Block,
},
/// Executes a block repeatedly.

View File

@@ -148,8 +148,8 @@ where
ref default,
} => {
self.traverse_expr(selector);
for &(ref case, _) in cases.values() {
self.traverse(case);
for case in cases.iter() {
self.traverse(&case.body);
}
self.traverse(default);
}

View File

@@ -22,9 +22,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
ref mut cases,
ref mut default,
}) => {
for case in cases.values_mut() {
if let (ref mut b, None) = *case {
ensure_block_returns(b);
for case in cases.iter_mut() {
if !case.fall_through {
ensure_block_returns(&mut case.body);
}
}
ensure_block_returns(default);