diff --git a/grammars/wgsl.pest b/grammars/wgsl.pest index 045248c583..3978c9f13f 100644 --- a/grammars/wgsl.pest +++ b/grammars/wgsl.pest @@ -75,8 +75,7 @@ const_literal = { int_literal // | UINT_LITERAL // | FLOAT_LITERAL - | "true" - | "false" + | bool_literal } const_expr = { @@ -106,9 +105,11 @@ scalar_type = { | "u32" } +return_statement = { "return" ~ primary_expression? } + statement = { ";" - | "return" ~ primary_expression ~ ";" + | return_statement ~ ";" // | if_stmt // | unless_stmt // | regardless_stmt @@ -122,12 +123,13 @@ statement = { // | assignment_stmt SEMICOLON } -primary_expression = { +primary_expression = _{ const_literal } ident = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* } int_literal = @{ ("-"? ~ "0x"? ~ (ASCII_DIGIT | 'a'..'f' | 'A'..'F')+) | "0" | ("-"? ~ '1'..'9' ~ ASCII_DIGIT*) } +bool_literal = @{ "true" | "false" } string = @{ "\"" ~ ( "\"\"" | (!"\"" ~ ANY) )* ~ "\"" } WHITESPACE = _{ " " | "\t" | "\n" } diff --git a/src/back/msl.rs b/src/back/msl.rs index 956dfddaff..9a5813e9f6 100644 --- a/src/back/msl.rs +++ b/src/back/msl.rs @@ -275,6 +275,7 @@ fn scalar_kind_string(kind: crate::ScalarKind) -> &'static str { crate::ScalarKind::Float => "float", crate::ScalarKind::Sint => "signed int", crate::ScalarKind::Uint => "unsigned int", + crate::ScalarKind::Bool => "bool", } } @@ -361,6 +362,10 @@ impl Writer { write!(self.out, "{}", value)?; crate::ScalarKind::Float } + crate::ConstantInner::Bool(value) => { + write!(self.out, "{}", value)?; + crate::ScalarKind::Bool + } }; let width = 32; //TODO: not sure how to get that... Ok(MaybeOwned::Owned(crate::TypeInner::Scalar { kind, width })) diff --git a/src/front/wgsl.rs b/src/front/wgsl.rs index d08a59dd85..9106f66e9a 100644 --- a/src/front/wgsl.rs +++ b/src/front/wgsl.rs @@ -13,6 +13,7 @@ pub enum Error { Pest(pest::error::Error), BadInt(std::num::ParseIntError), BadStorageClass(String), + BadBool(String), UnknownType(String), } impl From> for Error { @@ -41,7 +42,7 @@ impl Parser { Ok(pair.as_str().parse()?) } - fn _parse_int_literal(pair: pest::iterators::Pair) -> Result { + fn parse_int_literal(pair: pest::iterators::Pair) -> Result { let istr = pair.as_str(); let (sign, istr) = match &istr[..1] { "_" => (-1, &istr[1..]), @@ -159,11 +160,58 @@ impl Parser { Ok(crate::TypeInner::Struct { members }) } + fn parse_const_literal( + const_literal: pest::iterators::Pair, + const_store: &mut Storage, + ) -> Result, Error> { + let inner = match const_literal.as_rule() { + Rule::int_literal => { + let value = Self::parse_int_literal(const_literal)?; + crate::ConstantInner::Sint(value as i64) + } + Rule::bool_literal => { + let value = match const_literal.as_str() { + "true" => true, + "false" => false, + other => return Err(Error::BadBool(other.to_owned())), + }; + crate::ConstantInner::Bool(value) + } + ref other => panic!("Unknown const literal {:?}", other), + }; + Ok(const_store.append(crate::Constant { + name: None, + specialization: None, + inner, + })) + } + + fn parse_primary_expression( + &self, + primary_expression: pest::iterators::Pair, + function: &mut crate::Function, + const_store: &mut Storage, + ) -> Result, Error> { + let expression = match primary_expression.as_rule() { + Rule::const_literal => { + let const_literal = primary_expression.into_inner().next().unwrap(); + let token = Self::parse_const_literal(const_literal, const_store)?; + crate::Expression::Constant(token) + } + ref other => panic!("Unknown expression {:?}", other), + }; + Ok(function.expressions.append(expression)) + } + fn parse_function_decl( &self, function_decl: pest::iterators::Pair, module: &mut crate::Module, ) -> Result, Error> { + enum Ident { + Parameter(u8), + } + let mut lookup_symbols = FastHashMap::default(); assert_eq!(function_decl.as_rule(), Rule::function_decl); let mut function_decl_pairs = function_decl.into_inner(); @@ -171,24 +219,53 @@ impl Parser { assert_eq!(function_header.as_rule(), Rule::function_header); let mut function_header_pairs = function_header.into_inner(); let fun_name = function_header_pairs.next().unwrap().as_str().to_owned(); - let param_list = function_header_pairs.next().unwrap(); - assert_eq!(param_list.as_rule(), Rule::param_list); - let function_type_decl = function_header_pairs.next().unwrap(); - let fun = crate::Function { + let mut fun = crate::Function { name: Some(fun_name), control: spirv::FunctionControl::empty(), parameter_types: Vec::new(), - return_type: if function_type_decl.as_rule() == Rule::type_decl { - Some(self.parse_type_decl(function_type_decl, &mut module.types)?) - } else { - None - }, + return_type: None, expressions: Storage::new(), body: Vec::new(), }; + let param_list = function_header_pairs.next().unwrap(); + assert_eq!(param_list.as_rule(), Rule::param_list); + for (i, variable_ident_decl) in param_list.into_inner().enumerate() { + assert_eq!(variable_ident_decl.as_rule(), Rule::variable_ident_decl); + let mut variable_ident_decl_pairs = variable_ident_decl.into_inner(); + let param_name = variable_ident_decl_pairs.next().unwrap().as_str().to_owned(); + lookup_symbols.insert(param_name, Ident::Parameter(i as u8)); + let param_type_decl = variable_ident_decl_pairs.next().unwrap(); + let ty = self.parse_type_decl(param_type_decl, &mut module.types)?; + fun.parameter_types.push(ty); + } + let function_type_decl = function_header_pairs.next().unwrap(); + if function_type_decl.as_rule() == Rule::type_decl { + let ty = self.parse_type_decl(function_type_decl, &mut module.types)?; + fun.return_type = Some(ty); + } let function_body = function_decl_pairs.next().unwrap(); assert_eq!(function_body.as_rule(), Rule::body_stmt); + for statement in function_body.into_inner() { + assert_eq!(statement.as_rule(), Rule::statement); + let mut statement_pairs = statement.into_inner(); + let first_statement = match statement_pairs.next() { + Some(st) => st, + None => continue, + }; + let stmt = match first_statement.as_rule() { + Rule::return_statement => { + let mut return_pairs = first_statement.into_inner(); + let value = match return_pairs.next() { + Some(st) => Some(self.parse_primary_expression(st, &mut fun, &mut module.constants)?), + None => None, + }; + crate::Statement::Return { value } + } + ref other => panic!("Unknown statement {:?}", other), + }; + fun.body.push(stmt); + } Ok(module.functions.append(fun)) } diff --git a/src/lib.rs b/src/lib.rs index 579e29dbd7..70e28471c8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,6 +39,7 @@ pub enum ScalarKind { Sint, Uint, Float, + Bool, } #[repr(u8)] @@ -94,6 +95,7 @@ pub enum ConstantInner { Sint(i64), Uint(u64), Float(f64), + Bool(bool), } #[derive(Clone, Debug, PartialEq)]