From 7b54f9dfd23548473a4210f606e16884abc6924f Mon Sep 17 00:00:00 2001 From: Jamie Nicol Date: Mon, 3 Mar 2025 14:33:08 +0000 Subject: [PATCH] [naga wgsl-in] Concretize base type prior to non-constant indexed access (#7260) Parsing currently fails for shaders that attempt to dynamically index an abstract-typed array (or vector, etc), like so: var x = array(1, 2, 3)[i]; This is caused by attempting to concretize the Expression::Access expression, which the ConstantEvaluator fails to do so due to the presence of a non-constant expression. To solve this, this patch concretizes the base type *prior* to indexing it (for non-constant indices), meaning the constant evaluator never sees any non-constant expressions. This matches the WGSL specification: When an abstract array value e is indexed by an expression that is not a const-expression, then the array is concretized before the index is applied. (Similar applies for both vectors and matrices, too.) This may be somewhat non-optimal in that if there are multiple accesses of the same abstract expression, we will produce duplicated concretized versions of that expression. This seems unlikely to be a major issue in practice, and we can always improve this if and when we encounter a real issue caused by it. --- naga/src/front/wgsl/lower/mod.rs | 16 +- naga/tests/in/wgsl/const-exprs.wgsl | 15 + .../out/glsl/const-exprs.main.Compute.glsl | 10 + naga/tests/out/hlsl/const-exprs.hlsl | 18 + naga/tests/out/msl/const-exprs.msl | 15 + naga/tests/out/spv/const-exprs.spvasm | 321 ++++++++++-------- naga/tests/out/wgsl/const-exprs.wgsl | 11 + 7 files changed, 258 insertions(+), 148 deletions(-) diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 60070284ec..4d8a8c87fe 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2035,10 +2035,18 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } - lowered_base.map(|base| match ctx.const_eval_expr_to_u32(index).ok() { - Some(index) => crate::Expression::AccessIndex { base, index }, - None => crate::Expression::Access { base, index }, - }) + lowered_base.try_map(|base| match ctx.const_eval_expr_to_u32(index).ok() { + Some(index) => Ok::<_, Error>(crate::Expression::AccessIndex { base, index }), + None => { + // When an abstract array value e is indexed by an expression + // that is not a const-expression, then the array is concretized + // before the index is applied. + // https://www.w3.org/TR/WGSL/#array-access-expr + // Also applies to vectors and matrices. + let base = ctx.concretize(base)?; + Ok(crate::Expression::Access { base, index }) + } + })? } ast::Expression::Member { base, ref field } => { let mut lowered_base = self.expression_for_reference(base, ctx)?; diff --git a/naga/tests/in/wgsl/const-exprs.wgsl b/naga/tests/in/wgsl/const-exprs.wgsl index c8ab84cd3f..78884628bb 100644 --- a/naga/tests/in/wgsl/const-exprs.wgsl +++ b/naga/tests/in/wgsl/const-exprs.wgsl @@ -115,3 +115,18 @@ fn test_local_const() { const local_const = 2; var arr: array; } + +const ABSTRACT_ARRAY = array(1, 2, 3, 4, 5, 6, 7, 8, 9); +const ABSTRACT_VECTOR = vec4(1, 2, 3, 4); + +fn abstract_access(i: u32) { + // Constant indexing of abstract types is allowed, therefore we can assign + // to f32 or u32 vars just fine. + var a: f32 = ABSTRACT_ARRAY[0]; + var b: u32 = ABSTRACT_VECTOR.x; + + // For non constant indices the base type is concretized prior to indexing, + // therefore we can only assign to i32 in this case. + var c: i32 = ABSTRACT_ARRAY[i]; + var d: i32 = ABSTRACT_VECTOR[i]; +} diff --git a/naga/tests/out/glsl/const-exprs.main.Compute.glsl b/naga/tests/out/glsl/const-exprs.main.Compute.glsl index e4d50b297a..e301b4c9eb 100644 --- a/naga/tests/out/glsl/const-exprs.main.Compute.glsl +++ b/naga/tests/out/glsl/const-exprs.main.Compute.glsl @@ -112,6 +112,16 @@ void relational() { return; } +void abstract_access(uint i) { + float a_1 = 1.0; + uint b_1 = 1u; + int c_1 = 0; + int d = 0; + c_1 = int[9](1, 2, 3, 4, 5, 6, 7, 8, 9)[i]; + d = ivec4(1, 2, 3, 4)[i]; + return; +} + void main() { swizzle_of_compose(); index_of_compose(); diff --git a/naga/tests/out/hlsl/const-exprs.hlsl b/naga/tests/out/hlsl/const-exprs.hlsl index 40222d54e4..19ffb95ba0 100644 --- a/naga/tests/out/hlsl/const-exprs.hlsl +++ b/naga/tests/out/hlsl/const-exprs.hlsl @@ -125,6 +125,24 @@ void relational() return; } +typedef int ret_Constructarray9_int_[9]; +ret_Constructarray9_int_ Constructarray9_int_(int arg0, int arg1, int arg2, int arg3, int arg4, int arg5, int arg6, int arg7, int arg8) { + int ret[9] = { arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 }; + return ret; +} + +void abstract_access(uint i) +{ + float a_1 = 1.0; + uint b_1 = 1u; + int c_1 = (int)0; + int d = (int)0; + + c_1 = Constructarray9_int_(int(1), int(2), int(3), int(4), int(5), int(6), int(7), int(8), int(9))[min(uint(i), 8u)]; + d = int4(int(1), int(2), int(3), int(4))[min(uint(i), 3u)]; + return; +} + [numthreads(2, 3, 1)] void main() { diff --git a/naga/tests/out/msl/const-exprs.msl b/naga/tests/out/msl/const-exprs.msl index fd22ef72ac..484d2ed34f 100644 --- a/naga/tests/out/msl/const-exprs.msl +++ b/naga/tests/out/msl/const-exprs.msl @@ -7,6 +7,9 @@ using metal::uint; struct type_6 { float inner[2]; }; +struct type_10 { + int inner[9]; +}; constant uint TWO = 2u; constant int THREE = 3; constant bool TRUE = true; @@ -125,6 +128,18 @@ void relational( return; } +void abstract_access( + uint i +) { + float a_1 = 1.0; + uint b_1 = 1u; + int c_1 = {}; + int d = {}; + c_1 = type_10 {1, 2, 3, 4, 5, 6, 7, 8, 9}.inner[i]; + d = metal::int4(1, 2, 3, 4)[i]; + return; +} + kernel void main_( ) { swizzle_of_compose(); diff --git a/naga/tests/out/spv/const-exprs.spvasm b/naga/tests/out/spv/const-exprs.spvasm index ddd837b2f0..48ae358065 100644 --- a/naga/tests/out/spv/const-exprs.spvasm +++ b/naga/tests/out/spv/const-exprs.spvasm @@ -1,13 +1,14 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 140 +; Bound: 166 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %130 "main" -OpExecutionMode %130 LocalSize 2 3 1 +OpEntryPoint GLCompute %156 "main" +OpExecutionMode %156 LocalSize 2 3 1 OpDecorate %9 ArrayStride 4 +OpDecorate %14 ArrayStride 4 %2 = OpTypeVoid %3 = OpTypeInt 32 0 %4 = OpTypeInt 32 1 @@ -20,174 +21,206 @@ OpDecorate %9 ArrayStride 4 %11 = OpTypeVector %7 2 %12 = OpTypeVector %5 2 %13 = OpTypeVector %4 3 -%14 = OpConstant %4 3 -%15 = OpConstantTrue %5 -%16 = OpConstantFalse %5 -%17 = OpConstant %4 4 -%18 = OpConstant %4 8 -%19 = OpConstant %7 3.141 -%20 = OpConstant %7 6.282 -%21 = OpConstant %7 0.44444445 -%22 = OpConstant %7 0.0 -%23 = OpConstantComposite %8 %21 %22 %22 %22 -%24 = OpConstant %4 0 -%25 = OpConstant %4 1 -%26 = OpConstant %4 2 -%27 = OpConstant %7 4.0 -%28 = OpConstant %7 5.0 -%29 = OpConstantComposite %11 %27 %28 -%30 = OpConstantComposite %12 %15 %16 -%33 = OpTypeFunction %2 -%34 = OpConstantComposite %6 %17 %14 %26 %25 -%36 = OpTypePointer Function %6 -%41 = OpTypePointer Function %4 -%45 = OpConstant %4 6 -%50 = OpConstant %4 30 -%51 = OpConstant %4 70 -%54 = OpConstantNull %4 +%15 = OpConstant %3 9 +%14 = OpTypeArray %4 %15 +%16 = OpConstant %4 3 +%17 = OpConstantTrue %5 +%18 = OpConstantFalse %5 +%19 = OpConstant %4 4 +%20 = OpConstant %4 8 +%21 = OpConstant %7 3.141 +%22 = OpConstant %7 6.282 +%23 = OpConstant %7 0.44444445 +%24 = OpConstant %7 0.0 +%25 = OpConstantComposite %8 %23 %24 %24 %24 +%26 = OpConstant %4 0 +%27 = OpConstant %4 1 +%28 = OpConstant %4 2 +%29 = OpConstant %7 4.0 +%30 = OpConstant %7 5.0 +%31 = OpConstantComposite %11 %29 %30 +%32 = OpConstantComposite %12 %17 %18 +%35 = OpTypeFunction %2 +%36 = OpConstantComposite %6 %19 %16 %28 %27 +%38 = OpTypePointer Function %6 +%43 = OpTypePointer Function %4 +%47 = OpConstant %4 6 +%52 = OpConstant %4 30 +%53 = OpConstant %4 70 %56 = OpConstantNull %4 -%59 = OpConstantNull %6 -%70 = OpConstant %4 -4 -%71 = OpConstantComposite %6 %70 %70 %70 %70 -%80 = OpConstant %7 1.0 -%81 = OpConstant %7 2.0 -%82 = OpConstantComposite %8 %81 %80 %80 %80 -%84 = OpTypePointer Function %8 -%89 = OpTypePointer Function %9 -%90 = OpConstantNull %9 -%95 = OpTypeFunction %3 %4 -%96 = OpConstant %3 10 -%97 = OpConstant %3 20 -%98 = OpConstant %3 30 -%99 = OpConstant %3 0 -%106 = OpConstantNull %3 -%109 = OpConstantComposite %13 %25 %25 %25 -%110 = OpConstantComposite %13 %24 %25 %26 -%111 = OpConstantComposite %13 %25 %24 %26 -%113 = OpTypePointer Function %13 -%120 = OpTypePointer Function %5 -%32 = OpFunction %2 None %33 -%31 = OpLabel -%35 = OpVariable %36 Function %34 -OpBranch %37 -%37 = OpLabel +%58 = OpConstantNull %4 +%61 = OpConstantNull %6 +%72 = OpConstant %4 -4 +%73 = OpConstantComposite %6 %72 %72 %72 %72 +%82 = OpConstant %7 1.0 +%83 = OpConstant %7 2.0 +%84 = OpConstantComposite %8 %83 %82 %82 %82 +%86 = OpTypePointer Function %8 +%91 = OpTypePointer Function %9 +%92 = OpConstantNull %9 +%97 = OpTypeFunction %3 %4 +%98 = OpConstant %3 10 +%99 = OpConstant %3 20 +%100 = OpConstant %3 30 +%101 = OpConstant %3 0 +%108 = OpConstantNull %3 +%111 = OpConstantComposite %13 %27 %27 %27 +%112 = OpConstantComposite %13 %26 %27 %28 +%113 = OpConstantComposite %13 %27 %26 %28 +%115 = OpTypePointer Function %13 +%122 = OpTypePointer Function %5 +%134 = OpTypeFunction %2 %3 +%135 = OpConstant %3 1 +%136 = OpConstant %4 5 +%137 = OpConstant %4 7 +%138 = OpConstant %4 9 +%139 = OpConstantComposite %14 %27 %28 %16 %19 %136 %47 %137 %20 %138 +%140 = OpConstantComposite %6 %27 %28 %16 %19 +%142 = OpTypePointer Function %7 +%144 = OpTypePointer Function %3 +%146 = OpConstantNull %4 +%148 = OpConstantNull %4 +%150 = OpTypePointer Function %14 +%34 = OpFunction %2 None %35 +%33 = OpLabel +%37 = OpVariable %38 Function %36 +OpBranch %39 +%39 = OpLabel OpReturn OpFunctionEnd -%39 = OpFunction %2 None %33 -%38 = OpLabel -%40 = OpVariable %41 Function %26 -OpBranch %42 -%42 = OpLabel +%41 = OpFunction %2 None %35 +%40 = OpLabel +%42 = OpVariable %43 Function %28 +OpBranch %44 +%44 = OpLabel OpReturn OpFunctionEnd -%44 = OpFunction %2 None %33 -%43 = OpLabel -%46 = OpVariable %41 Function %45 -OpBranch %47 -%47 = OpLabel +%46 = OpFunction %2 None %35 +%45 = OpLabel +%48 = OpVariable %43 Function %47 +OpBranch %49 +%49 = OpLabel OpReturn OpFunctionEnd -%49 = OpFunction %2 None %33 -%48 = OpLabel -%58 = OpVariable %36 Function %59 -%53 = OpVariable %41 Function %54 -%57 = OpVariable %41 Function %51 -%52 = OpVariable %41 Function %50 -%55 = OpVariable %41 Function %56 -OpBranch %60 -%60 = OpLabel -%61 = OpLoad %4 %52 -OpStore %53 %61 -%62 = OpLoad %4 %53 -OpStore %55 %62 -%63 = OpLoad %4 %52 -%64 = OpLoad %4 %53 -%65 = OpLoad %4 %55 -%66 = OpLoad %4 %57 -%67 = OpCompositeConstruct %6 %63 %64 %65 %66 -OpStore %58 %67 +%51 = OpFunction %2 None %35 +%50 = OpLabel +%60 = OpVariable %38 Function %61 +%55 = OpVariable %43 Function %56 +%59 = OpVariable %43 Function %53 +%54 = OpVariable %43 Function %52 +%57 = OpVariable %43 Function %58 +OpBranch %62 +%62 = OpLabel +%63 = OpLoad %4 %54 +OpStore %55 %63 +%64 = OpLoad %4 %55 +OpStore %57 %64 +%65 = OpLoad %4 %54 +%66 = OpLoad %4 %55 +%67 = OpLoad %4 %57 +%68 = OpLoad %4 %59 +%69 = OpCompositeConstruct %6 %65 %66 %67 %68 +OpStore %60 %69 OpReturn OpFunctionEnd -%69 = OpFunction %2 None %33 -%68 = OpLabel -%72 = OpVariable %36 Function %71 -OpBranch %73 -%73 = OpLabel +%71 = OpFunction %2 None %35 +%70 = OpLabel +%74 = OpVariable %38 Function %73 +OpBranch %75 +%75 = OpLabel OpReturn OpFunctionEnd -%75 = OpFunction %2 None %33 -%74 = OpLabel -%76 = OpVariable %36 Function %71 -OpBranch %77 -%77 = OpLabel +%77 = OpFunction %2 None %35 +%76 = OpLabel +%78 = OpVariable %38 Function %73 +OpBranch %79 +%79 = OpLabel OpReturn OpFunctionEnd -%79 = OpFunction %2 None %33 -%78 = OpLabel -%83 = OpVariable %84 Function %82 -OpBranch %85 -%85 = OpLabel +%81 = OpFunction %2 None %35 +%80 = OpLabel +%85 = OpVariable %86 Function %84 +OpBranch %87 +%87 = OpLabel OpReturn OpFunctionEnd -%87 = OpFunction %2 None %33 -%86 = OpLabel -%88 = OpVariable %89 Function %90 -OpBranch %91 -%91 = OpLabel +%89 = OpFunction %2 None %35 +%88 = OpLabel +%90 = OpVariable %91 Function %92 +OpBranch %93 +%93 = OpLabel OpReturn OpFunctionEnd -%94 = OpFunction %3 None %95 -%93 = OpFunctionParameter %4 -%92 = OpLabel -OpBranch %100 -%100 = OpLabel -OpSelectionMerge %101 None -OpSwitch %93 %105 0 %102 1 %103 2 %104 +%96 = OpFunction %3 None %97 +%95 = OpFunctionParameter %4 +%94 = OpLabel +OpBranch %102 %102 = OpLabel -OpReturnValue %96 -%103 = OpLabel -OpReturnValue %97 +OpSelectionMerge %103 None +OpSwitch %95 %107 0 %104 1 %105 2 %106 %104 = OpLabel OpReturnValue %98 %105 = OpLabel OpReturnValue %99 -%101 = OpLabel -OpReturnValue %106 -OpFunctionEnd -%108 = OpFunction %2 None %33 +%106 = OpLabel +OpReturnValue %100 %107 = OpLabel -%112 = OpVariable %113 Function %109 -%114 = OpVariable %113 Function %110 -%115 = OpVariable %113 Function %111 -OpBranch %116 -%116 = OpLabel +OpReturnValue %101 +%103 = OpLabel +OpReturnValue %108 +OpFunctionEnd +%110 = OpFunction %2 None %35 +%109 = OpLabel +%114 = OpVariable %115 Function %111 +%116 = OpVariable %115 Function %112 +%117 = OpVariable %115 Function %113 +OpBranch %118 +%118 = OpLabel OpReturn OpFunctionEnd -%118 = OpFunction %2 None %33 -%117 = OpLabel -%126 = OpVariable %120 Function %16 -%123 = OpVariable %120 Function %15 -%119 = OpVariable %120 Function %16 -%127 = OpVariable %120 Function %15 -%124 = OpVariable %120 Function %16 -%121 = OpVariable %120 Function %15 -%125 = OpVariable %120 Function %15 -%122 = OpVariable %120 Function %16 -OpBranch %128 -%128 = OpLabel +%120 = OpFunction %2 None %35 +%119 = OpLabel +%128 = OpVariable %122 Function %18 +%125 = OpVariable %122 Function %17 +%121 = OpVariable %122 Function %18 +%129 = OpVariable %122 Function %17 +%126 = OpVariable %122 Function %18 +%123 = OpVariable %122 Function %17 +%127 = OpVariable %122 Function %17 +%124 = OpVariable %122 Function %18 +OpBranch %130 +%130 = OpLabel OpReturn OpFunctionEnd -%130 = OpFunction %2 None %33 -%129 = OpLabel -OpBranch %131 +%133 = OpFunction %2 None %134 +%132 = OpFunctionParameter %3 %131 = OpLabel -%132 = OpFunctionCall %2 %32 -%133 = OpFunctionCall %2 %39 -%134 = OpFunctionCall %2 %44 -%135 = OpFunctionCall %2 %49 -%136 = OpFunctionCall %2 %69 -%137 = OpFunctionCall %2 %75 -%138 = OpFunctionCall %2 %79 -%139 = OpFunctionCall %2 %87 +%143 = OpVariable %144 Function %135 +%147 = OpVariable %43 Function %148 +%141 = OpVariable %142 Function %82 +%145 = OpVariable %43 Function %146 +%151 = OpVariable %150 Function +OpBranch %149 +%149 = OpLabel +OpStore %151 %139 +%152 = OpAccessChain %43 %151 %132 +%153 = OpLoad %4 %152 +OpStore %145 %153 +%154 = OpVectorExtractDynamic %4 %140 %132 +OpStore %147 %154 +OpReturn +OpFunctionEnd +%156 = OpFunction %2 None %35 +%155 = OpLabel +OpBranch %157 +%157 = OpLabel +%158 = OpFunctionCall %2 %34 +%159 = OpFunctionCall %2 %41 +%160 = OpFunctionCall %2 %46 +%161 = OpFunctionCall %2 %51 +%162 = OpFunctionCall %2 %71 +%163 = OpFunctionCall %2 %77 +%164 = OpFunctionCall %2 %81 +%165 = OpFunctionCall %2 %89 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/const-exprs.wgsl b/naga/tests/out/wgsl/const-exprs.wgsl index 1e94167c24..6d13106948 100644 --- a/naga/tests/out/wgsl/const-exprs.wgsl +++ b/naga/tests/out/wgsl/const-exprs.wgsl @@ -114,6 +114,17 @@ fn relational() { return; } +fn abstract_access(i: u32) { + var a_1: f32 = 1f; + var b_1: u32 = 1u; + var c_1: i32; + var d: i32; + + c_1 = array(1i, 2i, 3i, 4i, 5i, 6i, 7i, 8i, 9i)[i]; + d = vec4(1i, 2i, 3i, 4i)[i]; + return; +} + @compute @workgroup_size(2, 3, 1) fn main() { swizzle_of_compose();