From 960551f952015cb92c16d6382359252d32cd7bc2 Mon Sep 17 00:00:00 2001 From: Lachlan Sneff Date: Fri, 30 Apr 2021 00:20:09 -0400 Subject: [PATCH] Add support for arrayLength to the wgsl frontend (#805) * Add support for arrayLength to the wgsl frontend * Fix clippy warning --- src/front/wgsl/mod.rs | 5 +++++ src/front/wgsl/tests.rs | 24 ++++++++++++++++++++++++ src/valid/expression.rs | 12 +++++++++++- 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index d2ac7a0b10..f24eea8fd1 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -741,6 +741,11 @@ impl Parser { accept, reject, } + } else if name == "arrayLength" { + lexer.open_arguments()?; + let array = self.parse_singular_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::ArrayLength(array) } else { // texture sampling match name { diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index d7428b65b0..a537c62d1d 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -304,3 +304,27 @@ fn parse_struct_instantiation() { ) .unwrap(); } + +#[test] +fn parse_array_length() { + parse_str( + " + [[block]] + struct Foo { + data: [[stride(4)]] array; + }; // this is used as both input and output for convenience + + [[group(0), binding(0)]] + var foo: [[access(read_write)]] Foo; + + [[group(0), binding(1)]] + var bar: [[access(read)]] array; + + fn foo() { + var x: u32 = arrayLength(foo.data); + var y: u32 = arrayLength(bar); + } + ", + ) + .unwrap(); +} diff --git a/src/valid/expression.rs b/src/valid/expression.rs index df79a3dff7..1089b057e9 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1110,7 +1110,17 @@ impl super::Validator { } E::Call(function) => other_infos[function.index()].available_stages, E::ArrayLength(expr) => match *resolver.resolve(expr)? { - Ti::Array { .. } => ShaderStages::all(), + Ti::Pointer { base, .. } => { + if let Some(&Ti::Array { + size: crate::ArraySize::Dynamic, + .. + }) = resolver.types.try_get(base).map(|ty| &ty.inner) + { + ShaderStages::all() + } else { + return Err(ExpressionError::InvalidArrayType(expr)); + } + } ref other => { log::error!("Array length of {:?}", other); return Err(ExpressionError::InvalidArrayType(expr));