diff --git a/src/front/wgsl/conv.rs b/src/front/wgsl/conv.rs index d3bccc7e60..674f2fee9b 100644 --- a/src/front/wgsl/conv.rs +++ b/src/front/wgsl/conv.rs @@ -107,6 +107,7 @@ pub fn get_intrinsic(word: &str) -> Option { _ => None, } } + pub fn get_derivative(word: &str) -> Option { match word { "dpdx" => Some(crate::DerivativeAxis::X), @@ -115,3 +116,11 @@ pub fn get_derivative(word: &str) -> Option { _ => None, } } + +// Returns argument count on success +pub fn get_standard_fun(word: &str) -> Option { + match word { + "min" | "max" => Some(2), + _ => None, + } +} diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index fe36e24854..98d4cc513e 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -607,49 +607,54 @@ impl Parser { Token::Word(word) => { if let Some(fun) = conv::get_intrinsic(word) { lexer.expect(Token::Paren('('))?; - let argument = self.parse_primary_expression(lexer, ctx.reborrow())?; + let argument = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::Intrinsic { fun, argument }) } else if let Some(axis) = conv::get_derivative(word) { lexer.expect(Token::Paren('('))?; - let expr = self.parse_primary_expression(lexer, ctx.reborrow())?; + let expr = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::Derivative { axis, expr }) } else if let Some((kind, _width)) = conv::get_scalar_type(word) { lexer.expect(Token::Paren('('))?; - let expr = self.parse_primary_expression(lexer, ctx.reborrow())?; + let expr = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::As { expr, kind, convert: true, }) + } else if let Some(arg_count) = conv::get_standard_fun(word) { + lexer.expect(Token::Paren('('))?; + let mut arguments = Vec::with_capacity(arg_count); + for i in 0..arg_count { + let arg = self.parse_general_expression(lexer, ctx.reborrow())?; + arguments.push(arg); + lexer.expect(if i + 1 == arg_count { + Token::Paren(')') + } else { + Token::Separator(',') + })?; + } + Some(crate::Expression::Call { + origin: crate::FunctionOrigin::External(word.to_string()), + arguments, + }) } else { match word { - "min" | "max" => { - lexer.expect(Token::Paren('('))?; - let a = self.parse_primary_expression(lexer, ctx.reborrow())?; - lexer.expect(Token::Separator(','))?; - let b = self.parse_primary_expression(lexer, ctx.reborrow())?; - lexer.expect(Token::Paren(')'))?; - Some(crate::Expression::Call { - origin: crate::FunctionOrigin::External(word.to_string()), - arguments: vec![a, b], - }) - } "dot" => { lexer.expect(Token::Paren('('))?; - let a = self.parse_primary_expression(lexer, ctx.reborrow())?; + let a = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Separator(','))?; - let b = self.parse_primary_expression(lexer, ctx.reborrow())?; + let b = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::DotProduct(a, b)) } "cross" => { lexer.expect(Token::Paren('('))?; - let a = self.parse_primary_expression(lexer, ctx.reborrow())?; + let a = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Separator(','))?; - let b = self.parse_primary_expression(lexer, ctx.reborrow())?; + let b = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::CrossProduct(a, b)) } @@ -660,7 +665,7 @@ impl Parser { let sampler_name = lexer.next_ident()?; lexer.expect(Token::Separator(','))?; let coordinate = - self.parse_primary_expression(lexer, ctx.reborrow())?; + self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::ImageSample { image: ctx.lookup_ident.lookup(image_name)?, @@ -677,9 +682,9 @@ impl Parser { let sampler_name = lexer.next_ident()?; lexer.expect(Token::Separator(','))?; let coordinate = - self.parse_primary_expression(lexer, ctx.reborrow())?; + self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Separator(','))?; - let level = self.parse_primary_expression(lexer, ctx.reborrow())?; + let level = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::ImageSample { image: ctx.lookup_ident.lookup(image_name)?, @@ -696,9 +701,9 @@ impl Parser { let sampler_name = lexer.next_ident()?; lexer.expect(Token::Separator(','))?; let coordinate = - self.parse_primary_expression(lexer, ctx.reborrow())?; + self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Separator(','))?; - let bias = self.parse_primary_expression(lexer, ctx.reborrow())?; + let bias = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::ImageSample { image: ctx.lookup_ident.lookup(image_name)?, @@ -715,9 +720,9 @@ impl Parser { let sampler_name = lexer.next_ident()?; lexer.expect(Token::Separator(','))?; let coordinate = - self.parse_primary_expression(lexer, ctx.reborrow())?; + self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Separator(','))?; - let reference = self.parse_primary_expression(lexer, ctx.reborrow())?; + let reference = self.parse_general_expression(lexer, ctx.reborrow())?; lexer.expect(Token::Paren(')'))?; Some(crate::Expression::ImageSample { image: ctx.lookup_ident.lookup(image_name)?, @@ -733,7 +738,7 @@ impl Parser { let image = ctx.lookup_ident.lookup(image_name)?; lexer.expect(Token::Separator(','))?; let coordinate = - self.parse_primary_expression(lexer, ctx.reborrow())?; + self.parse_general_expression(lexer, ctx.reborrow())?; let is_storage = match *ctx.resolve_type(image)? { crate::TypeInner::Image { class: crate::ImageClass::Storage(_), diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index ccc34f229e..5d9f4f5a35 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -13,7 +13,19 @@ fn parse_type_cast() { const a : i32 = 2; fn main() { var x: f32 = f32(a); - #x = f32(i32(a + 1) / 2); //TODO + x = f32(i32(a + 1) / 2); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_standard_fun() { + parse_str( + " + fn main() { + var x: i32 = min(max(1, 2), 3); } ", )