From 6026e57404bf32ae7a2cb74c3acdf8852ebea301 Mon Sep 17 00:00:00 2001 From: Lachlan Sneff Date: Fri, 21 Aug 2020 00:13:25 -0400 Subject: [PATCH] [wgsl] Add more complete function calling support (#144) * Add function calling support to wgsl frontend * Fix external namespace with multiple namespaces * changes after code review * Don't re-tokenize std_namespace every time --- src/front/wgsl.rs | 158 ++++++++++++++++++++++++++++------------ src/lib.rs | 4 + test-data/function.wgsl | 12 +++ 3 files changed, 127 insertions(+), 47 deletions(-) create mode 100644 test-data/function.wgsl diff --git a/src/front/wgsl.rs b/src/front/wgsl.rs index 8ccdd41ad2..4c3fa54538 100644 --- a/src/front/wgsl.rs +++ b/src/front/wgsl.rs @@ -209,6 +209,8 @@ pub enum Error<'a> { ZeroStride, #[error("not a composite type: {0:?}")] NotCompositeType(crate::TypeInner), + #[error("function redefinition: `{0}`")] + FunctionRedefinition(&'a str), //MutabilityViolation(&'a str), // TODO: these could be replaced with more detailed errors #[error("other error")] @@ -237,7 +239,7 @@ impl<'a> Lexer<'a> { self.clone().next() } - fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> { + fn expect(&mut self, expected: Token<'_>) -> Result<(), Error<'a>> { let token = self.next(); if token == expected { Ok(()) @@ -246,7 +248,7 @@ impl<'a> Lexer<'a> { } } - fn skip(&mut self, what: Token<'a>) -> bool { + fn skip(&mut self, what: Token<'_>) -> bool { let (token, rest) = lex::consume_token(self.input); if token == what { self.input = rest; @@ -296,7 +298,7 @@ impl<'a> Lexer<'a> { Ok(pair) } - fn take_until(&mut self, what: Token<'a>) -> Result, Error<'a>> { + fn take_until(&mut self, what: Token<'_>) -> Result, Error<'a>> { let original_input = self.input; let initial_len = self.input.len(); let mut used_len = 0; @@ -449,7 +451,8 @@ pub struct ParseError<'a> { pub struct Parser { scopes: Vec, lookup_type: FastHashMap>, - std_namespace: Option, + function_lookup: FastHashMap>, + std_namespace: Option>, } impl Parser { @@ -457,6 +460,7 @@ impl Parser { Parser { scopes: Vec::new(), lookup_type: FastHashMap::default(), + function_lookup: FastHashMap::default(), std_namespace: None, } } @@ -543,6 +547,49 @@ impl Parser { } } + fn parse_function_call<'a>( + &mut self, + lexer: &Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result)>, Error<'a>> { + let mut lexer = lexer.clone(); + + let external_function = if let Some(std_namespaces) = self.std_namespace.as_deref() { + std_namespaces.iter().all(|namespace| { + lexer.skip(Token::Word(namespace)) && lexer.skip(Token::DoubleColon) + }) + } else { + false + }; + + let origin = if external_function { + let function = lexer.next_ident()?; + crate::FunctionOrigin::External(function.to_string()) + } else if let Ok(function) = lexer.next_ident() { + if let Some(&function) = self.function_lookup.get(function) { + crate::FunctionOrigin::Local(function) + } else { + return Ok(None); + } + } else { + return Ok(None); + }; + + if !lexer.skip(Token::Paren('(')) { + return Ok(None); + } + + let mut arguments = Vec::new(); + while !lexer.skip(Token::Paren(')')) { + if !arguments.is_empty() { + lexer.expect(Token::Separator(','))?; + } + let arg = self.parse_general_expression(&mut lexer, ctx.reborrow())?; + arguments.push(arg); + } + Ok(Some((crate::Expression::Call { origin, arguments }, lexer))) + } + fn parse_const_expression<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -654,22 +701,11 @@ impl Parser { self.scopes.pop(); return Ok(*handle); } - if self.std_namespace.as_deref() == Some(word) { - lexer.expect(Token::DoubleColon)?; - let name = lexer.next_ident()?; - let mut arguments = Vec::new(); - lexer.expect(Token::Paren('('))?; - while !lexer.skip(Token::Paren(')')) { - if !arguments.is_empty() { - lexer.expect(Token::Separator(','))?; - } - let arg = self.parse_general_expression(lexer, ctx.reborrow())?; - arguments.push(arg); - } - crate::Expression::Call { - origin: crate::FunctionOrigin::External(name.to_owned()), - arguments, - } + if let Some((expr, new_lexer)) = + self.parse_function_call(&backup, ctx.reborrow())? + { + *lexer = new_lexer; + expr } else { *lexer = backup; let ty = self.parse_type_decl(lexer, ctx.types)?; @@ -1295,6 +1331,7 @@ impl Parser { lexer: &mut Lexer<'a>, mut context: StatementContext<'a, '_, '_>, ) -> Result, Error<'a>> { + let backup = lexer.clone(); match lexer.next() { Token::Separator(';') => Ok(Some(crate::Statement::Empty)), Token::Paren('}') => Ok(None), @@ -1387,15 +1424,26 @@ impl Parser { "continue" => crate::Statement::Continue, ident => { // assignment - let var_expr = context.lookup_ident.lookup(ident)?; - let left = self.parse_postfix(lexer, context.as_expression(), var_expr)?; - lexer.expect(Token::Operation('='))?; - let value = - self.parse_general_expression(lexer, context.as_expression())?; - lexer.expect(Token::Separator(';'))?; - crate::Statement::Store { - pointer: left, - value, + if let Some(&var_expr) = context.lookup_ident.get(ident) { + let left = + self.parse_postfix(lexer, context.as_expression(), var_expr)?; + lexer.expect(Token::Operation('='))?; + let value = + self.parse_general_expression(lexer, context.as_expression())?; + lexer.expect(Token::Separator(';'))?; + crate::Statement::Store { + pointer: left, + value, + } + } else if let Some((expr, new_lexer)) = + self.parse_function_call(&backup, context.as_expression())? + { + *lexer = new_lexer; + context.expressions.append(expr); + lexer.expect(Token::Separator(';'))?; + crate::Statement::Empty + } else { + return Err(Error::UnknownIdent(ident)); } } }; @@ -1459,35 +1507,45 @@ impl Parser { } else { Some(self.parse_type_decl(lexer, &mut module.types)?) }; + + let fun_handle = module.functions.append(crate::Function { + name: Some(fun_name.to_string()), + parameter_types, + return_type, + global_usage: Vec::new(), + local_variables: Arena::new(), + expressions, + body: Vec::new(), + }); + if self + .function_lookup + .insert(fun_name.to_string(), fun_handle) + .is_some() + { + return Err(Error::FunctionRedefinition(fun_name)); + } + let fun = module.functions.get_mut(fun_handle); + // read body - let mut local_variables = Arena::new(); let mut typifier = Typifier::new(); - let body = self.parse_block( + fun.body = self.parse_block( lexer, StatementContext { lookup_ident: &mut lookup_ident, typifier: &mut typifier, - variables: &mut local_variables, - expressions: &mut expressions, + variables: &mut fun.local_variables, + expressions: &mut fun.expressions, types: &mut module.types, constants: &mut module.constants, global_vars: &module.global_variables, }, )?; // done - let global_usage = crate::GlobalUse::scan(&expressions, &body, &module.global_variables); + fun.global_usage = + crate::GlobalUse::scan(&fun.expressions, &fun.body, &module.global_variables); self.scopes.pop(); - let fun = crate::Function { - name: Some(fun_name.to_owned()), - parameter_types, - return_type, - global_usage, - local_variables, - expressions, - body, - }; - Ok(module.functions.append(fun)) + Ok(fun_handle) } fn parse_global_decl<'a>( @@ -1554,10 +1612,16 @@ impl Parser { other => return Err(Error::Unexpected(other)), }; lexer.expect(Token::Word("as"))?; - let namespace = lexer.next_ident()?; - lexer.expect(Token::Separator(';'))?; + let mut namespaces = Vec::new(); + loop { + namespaces.push(lexer.next_ident()?.to_owned()); + if lexer.skip(Token::Separator(';')) { + break; + } + lexer.expect(Token::DoubleColon)?; + } match path { - "GLSL.std.450" => self.std_namespace = Some(namespace.to_owned()), + "GLSL.std.450" => self.std_namespace = Some(namespaces), _ => return Err(Error::UnknownImport(path)), } self.scopes.pop(); diff --git a/src/lib.rs b/src/lib.rs index 867cacb9dd..753ebf7fb0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -423,6 +423,10 @@ pub enum DerivativeAxis { #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum FunctionOrigin { Local(Handle), + // External { + // namespace: String, // Maybe this should be a handle to a namespace Arena? + // function: String, + // }, External(String), } diff --git a/test-data/function.wgsl b/test-data/function.wgsl new file mode 100644 index 0000000000..7252e1cd1b --- /dev/null +++ b/test-data/function.wgsl @@ -0,0 +1,12 @@ +import "GLSL.std.450" as std::glsl; + +fn test_function(test: f32) -> f32 { + return test; +} + +fn main_vert() -> void { + var foo: f32 = std::glsl::distance(0.0, 1.0); + var test: f32 = test_function(1.0); +} + +entry_point vertex as "main" = main_vert; \ No newline at end of file