From ba19d8df34632b5bf068b1ca461ccd8c0cc43802 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 25 Mar 2024 19:11:23 -0700 Subject: [PATCH] [naga] Adjust RayQuery statements in override processing. --- naga/src/back/pipeline_constants.rs | 20 +- naga/tests/in/overrides-ray-query.param.ron | 18 ++ naga/tests/in/overrides-ray-query.wgsl | 21 ++ .../out/ir/overrides-ray-query.compact.ron | 259 ++++++++++++++++++ naga/tests/out/ir/overrides-ray-query.ron | 259 ++++++++++++++++++ naga/tests/out/msl/overrides-ray-query.msl | 45 +++ .../out/spv/overrides-ray-query.main.spvasm | 77 ++++++ naga/tests/snapshots.rs | 5 + 8 files changed, 702 insertions(+), 2 deletions(-) create mode 100644 naga/tests/in/overrides-ray-query.param.ron create mode 100644 naga/tests/in/overrides-ray-query.wgsl create mode 100644 naga/tests/out/ir/overrides-ray-query.compact.ron create mode 100644 naga/tests/out/ir/overrides-ray-query.ron create mode 100644 naga/tests/out/msl/overrides-ray-query.msl create mode 100644 naga/tests/out/spv/overrides-ray-query.main.spvasm diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index d41eeedef2..c1fd2d02cc 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -633,7 +633,7 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { Statement::Call { ref mut arguments, ref mut result, - .. + function: _, } => { for argument in arguments.iter_mut() { adjust(argument); @@ -642,8 +642,24 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { adjust(e); } } - Statement::RayQuery { ref mut query, .. } => { + Statement::RayQuery { + ref mut query, + ref mut fun, + } => { adjust(query); + match *fun { + crate::RayQueryFunction::Initialize { + ref mut acceleration_structure, + ref mut descriptor, + } => { + adjust(acceleration_structure); + adjust(descriptor); + } + crate::RayQueryFunction::Proceed { ref mut result } => { + adjust(result); + } + crate::RayQueryFunction::Terminate => {} + } } Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} } diff --git a/naga/tests/in/overrides-ray-query.param.ron b/naga/tests/in/overrides-ray-query.param.ron new file mode 100644 index 0000000000..588656aaac --- /dev/null +++ b/naga/tests/in/overrides-ray-query.param.ron @@ -0,0 +1,18 @@ +( + god_mode: true, + spv: ( + version: (1, 4), + separate_entry_points: true, + ), + msl: ( + lang_version: (2, 4), + spirv_cross_compatibility: false, + fake_missing_bindings: true, + zero_initialize_workgroup_memory: false, + per_entry_point_map: {}, + inline_samplers: [], + ), + pipeline_constants: { + "o": 2.0 + } +) diff --git a/naga/tests/in/overrides-ray-query.wgsl b/naga/tests/in/overrides-ray-query.wgsl new file mode 100644 index 0000000000..dca7447ed0 --- /dev/null +++ b/naga/tests/in/overrides-ray-query.wgsl @@ -0,0 +1,21 @@ +override o: f32; + +@group(0) @binding(0) +var acc_struct: acceleration_structure; + +@compute @workgroup_size(1) +fn main() { + var rq: ray_query; + + let desc = RayDesc( + RAY_FLAG_TERMINATE_ON_FIRST_HIT, + 0xFFu, + o * 17.0, + o * 19.0, + vec3(o * 23.0), + vec3(o * 29.0, o * 31.0, o * 37.0), + ); + rayQueryInitialize(&rq, acc_struct, desc); + + while (rayQueryProceed(&rq)) {} +} diff --git a/naga/tests/out/ir/overrides-ray-query.compact.ron b/naga/tests/out/ir/overrides-ray-query.compact.ron new file mode 100644 index 0000000000..b127259bbb --- /dev/null +++ b/naga/tests/out/ir/overrides-ray-query.compact.ron @@ -0,0 +1,259 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: AccelerationStructure, + ), + ( + name: None, + inner: RayQuery, + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: Some("RayDesc"), + inner: Struct( + members: [ + ( + name: Some("flags"), + ty: 4, + binding: None, + offset: 0, + ), + ( + name: Some("cull_mask"), + ty: 4, + binding: None, + offset: 4, + ), + ( + name: Some("tmin"), + ty: 1, + binding: None, + offset: 8, + ), + ( + name: Some("tmax"), + ty: 1, + binding: None, + offset: 12, + ), + ( + name: Some("origin"), + ty: 5, + binding: None, + offset: 16, + ), + ( + name: Some("dir"), + ty: 5, + binding: None, + offset: 32, + ), + ], + span: 48, + ), + ), + ], + special_types: ( + ray_desc: Some(6), + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("acc_struct"), + space: Handle, + binding: Some(( + group: 0, + binding: 0, + )), + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [ + ( + name: Some("rq"), + ty: 3, + init: None, + ), + ], + expressions: [ + LocalVariable(1), + Literal(U32(4)), + Literal(U32(255)), + Override(1), + Literal(F32(17.0)), + Binary( + op: Multiply, + left: 4, + right: 5, + ), + Override(1), + Literal(F32(19.0)), + Binary( + op: Multiply, + left: 7, + right: 8, + ), + Override(1), + Literal(F32(23.0)), + Binary( + op: Multiply, + left: 10, + right: 11, + ), + Splat( + size: Tri, + value: 12, + ), + Override(1), + Literal(F32(29.0)), + Binary( + op: Multiply, + left: 14, + right: 15, + ), + Override(1), + Literal(F32(31.0)), + Binary( + op: Multiply, + left: 17, + right: 18, + ), + Override(1), + Literal(F32(37.0)), + Binary( + op: Multiply, + left: 20, + right: 21, + ), + Compose( + ty: 5, + components: [ + 16, + 19, + 22, + ], + ), + Compose( + ty: 6, + components: [ + 2, + 3, + 6, + 9, + 13, + 23, + ], + ), + GlobalVariable(1), + RayQueryProceedResult, + ], + named_expressions: { + 24: "desc", + }, + body: [ + Emit(( + start: 5, + end: 6, + )), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 11, + end: 13, + )), + Emit(( + start: 15, + end: 16, + )), + Emit(( + start: 18, + end: 19, + )), + Emit(( + start: 21, + end: 24, + )), + RayQuery( + query: 1, + fun: Initialize( + acceleration_structure: 25, + descriptor: 24, + ), + ), + Loop( + body: [ + RayQuery( + query: 1, + fun: Proceed( + result: 26, + ), + ), + If( + condition: 26, + accept: [], + reject: [ + Break, + ], + ), + Block([]), + ], + continuing: [], + break_if: None, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides-ray-query.ron b/naga/tests/out/ir/overrides-ray-query.ron new file mode 100644 index 0000000000..b127259bbb --- /dev/null +++ b/naga/tests/out/ir/overrides-ray-query.ron @@ -0,0 +1,259 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: AccelerationStructure, + ), + ( + name: None, + inner: RayQuery, + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: Some("RayDesc"), + inner: Struct( + members: [ + ( + name: Some("flags"), + ty: 4, + binding: None, + offset: 0, + ), + ( + name: Some("cull_mask"), + ty: 4, + binding: None, + offset: 4, + ), + ( + name: Some("tmin"), + ty: 1, + binding: None, + offset: 8, + ), + ( + name: Some("tmax"), + ty: 1, + binding: None, + offset: 12, + ), + ( + name: Some("origin"), + ty: 5, + binding: None, + offset: 16, + ), + ( + name: Some("dir"), + ty: 5, + binding: None, + offset: 32, + ), + ], + span: 48, + ), + ), + ], + special_types: ( + ray_desc: Some(6), + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("acc_struct"), + space: Handle, + binding: Some(( + group: 0, + binding: 0, + )), + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [ + ( + name: Some("rq"), + ty: 3, + init: None, + ), + ], + expressions: [ + LocalVariable(1), + Literal(U32(4)), + Literal(U32(255)), + Override(1), + Literal(F32(17.0)), + Binary( + op: Multiply, + left: 4, + right: 5, + ), + Override(1), + Literal(F32(19.0)), + Binary( + op: Multiply, + left: 7, + right: 8, + ), + Override(1), + Literal(F32(23.0)), + Binary( + op: Multiply, + left: 10, + right: 11, + ), + Splat( + size: Tri, + value: 12, + ), + Override(1), + Literal(F32(29.0)), + Binary( + op: Multiply, + left: 14, + right: 15, + ), + Override(1), + Literal(F32(31.0)), + Binary( + op: Multiply, + left: 17, + right: 18, + ), + Override(1), + Literal(F32(37.0)), + Binary( + op: Multiply, + left: 20, + right: 21, + ), + Compose( + ty: 5, + components: [ + 16, + 19, + 22, + ], + ), + Compose( + ty: 6, + components: [ + 2, + 3, + 6, + 9, + 13, + 23, + ], + ), + GlobalVariable(1), + RayQueryProceedResult, + ], + named_expressions: { + 24: "desc", + }, + body: [ + Emit(( + start: 5, + end: 6, + )), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 11, + end: 13, + )), + Emit(( + start: 15, + end: 16, + )), + Emit(( + start: 18, + end: 19, + )), + Emit(( + start: 21, + end: 24, + )), + RayQuery( + query: 1, + fun: Initialize( + acceleration_structure: 25, + descriptor: 24, + ), + ), + Loop( + body: [ + RayQuery( + query: 1, + fun: Proceed( + result: 26, + ), + ), + If( + condition: 26, + accept: [], + reject: [ + Break, + ], + ), + Block([]), + ], + continuing: [], + break_if: None, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/msl/overrides-ray-query.msl b/naga/tests/out/msl/overrides-ray-query.msl new file mode 100644 index 0000000000..3a508b6f61 --- /dev/null +++ b/naga/tests/out/msl/overrides-ray-query.msl @@ -0,0 +1,45 @@ +// language: metal2.4 +#include +#include + +using metal::uint; +struct _RayQuery { + metal::raytracing::intersector intersector; + metal::raytracing::intersector::result_type intersection; + bool ready = false; +}; +constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) { + return ty==metal::raytracing::intersection_type::triangle ? 1 : + ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0; +} + +struct RayDesc { + uint flags; + uint cull_mask; + float tmin; + float tmax; + metal::float3 origin; + metal::float3 dir; +}; +constant float o = 2.0; + +kernel void main_( + metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]] +) { + _RayQuery rq = {}; + RayDesc desc = RayDesc {4u, 255u, 34.0, 38.0, metal::float3(46.0), metal::float3(58.0, 62.0, 74.0)}; + rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle); + rq.intersector.set_opacity_cull_mode((desc.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (desc.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none); + rq.intersector.force_opacity((desc.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (desc.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none); + rq.intersector.accept_any_intersection((desc.flags & 4) != 0); + rq.intersection = rq.intersector.intersect(metal::raytracing::ray(desc.origin, desc.dir, desc.tmin, desc.tmax), acc_struct, desc.cull_mask); rq.ready = true; + while(true) { + bool _e31 = rq.ready; + rq.ready = false; + if (_e31) { + } else { + break; + } + } + return; +} diff --git a/naga/tests/out/spv/overrides-ray-query.main.spvasm b/naga/tests/out/spv/overrides-ray-query.main.spvasm new file mode 100644 index 0000000000..a341393468 --- /dev/null +++ b/naga/tests/out/spv/overrides-ray-query.main.spvasm @@ -0,0 +1,77 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 46 +OpCapability Shader +OpCapability RayQueryKHR +OpExtension "SPV_KHR_ray_query" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %13 "main" %10 +OpExecutionMode %13 LocalSize 1 1 1 +OpMemberDecorate %8 0 Offset 0 +OpMemberDecorate %8 1 Offset 4 +OpMemberDecorate %8 2 Offset 8 +OpMemberDecorate %8 3 Offset 12 +OpMemberDecorate %8 4 Offset 16 +OpMemberDecorate %8 5 Offset 32 +OpDecorate %10 DescriptorSet 0 +OpDecorate %10 Binding 0 +%2 = OpTypeVoid +%3 = OpTypeFloat 32 +%4 = OpTypeAccelerationStructureNV +%5 = OpTypeRayQueryKHR +%6 = OpTypeInt 32 0 +%7 = OpTypeVector %3 3 +%8 = OpTypeStruct %6 %6 %3 %3 %7 %7 +%9 = OpConstant %3 2.0 +%11 = OpTypePointer UniformConstant %4 +%10 = OpVariable %11 UniformConstant +%14 = OpTypeFunction %2 +%16 = OpConstant %6 4 +%17 = OpConstant %6 255 +%18 = OpConstant %3 34.0 +%19 = OpConstant %3 38.0 +%20 = OpConstant %3 46.0 +%21 = OpConstantComposite %7 %20 %20 %20 +%22 = OpConstant %3 58.0 +%23 = OpConstant %3 62.0 +%24 = OpConstant %3 74.0 +%25 = OpConstantComposite %7 %22 %23 %24 +%26 = OpConstantComposite %8 %16 %17 %18 %19 %21 %25 +%28 = OpTypePointer Function %5 +%41 = OpTypeBool +%13 = OpFunction %2 None %14 +%12 = OpLabel +%27 = OpVariable %28 Function +%15 = OpLoad %4 %10 +OpBranch %29 +%29 = OpLabel +%30 = OpCompositeExtract %6 %26 0 +%31 = OpCompositeExtract %6 %26 1 +%32 = OpCompositeExtract %3 %26 2 +%33 = OpCompositeExtract %3 %26 3 +%34 = OpCompositeExtract %7 %26 4 +%35 = OpCompositeExtract %7 %26 5 +OpRayQueryInitializeKHR %27 %15 %30 %31 %34 %32 %35 %33 +OpBranch %36 +%36 = OpLabel +OpLoopMerge %37 %39 None +OpBranch %38 +%38 = OpLabel +%40 = OpRayQueryProceedKHR %41 %27 +OpSelectionMerge %42 None +OpBranchConditional %40 %42 %43 +%43 = OpLabel +OpBranch %37 +%42 = OpLabel +OpBranch %44 +%44 = OpLabel +OpBranch %45 +%45 = OpLabel +OpBranch %39 +%39 = OpLabel +OpBranch %36 +%37 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 151e8b3da3..94c50c7975 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -466,6 +466,7 @@ fn write_output_spv( ); } } else { + assert!(pipeline_constants.is_empty()); write_output_spv_inner(input, module, info, &options, None, "spvasm"); } } @@ -857,6 +858,10 @@ fn convert_wgsl() { "overrides-atomicCompareExchangeWeak", Targets::IR | Targets::SPIRV, ), + ( + "overrides-ray-query", + Targets::IR | Targets::SPIRV | Targets::METAL, + ), ]; for &(name, targets) in inputs.iter() {