From ab9c36441fc80f73394cf8152b52218860f5eafb Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 21 Feb 2023 23:31:32 -0800 Subject: [PATCH] fill up the ray query intersection struct --- src/back/spv/block.rs | 154 +------------------ src/back/spv/mod.rs | 1 + src/back/spv/ray.rs | 273 +++++++++++++++++++++++++++++++++ src/front/type_gen.rs | 91 ++++++++++- tests/in/ray-query.wgsl | 3 +- tests/out/spv/ray-query.spvasm | 122 +++++++++------ 6 files changed, 440 insertions(+), 204 deletions(-) create mode 100644 src/back/spv/ray.rs diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 77d7267fd8..b28b94fe91 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1387,59 +1387,10 @@ impl<'w> BlockContext<'w> { } crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, crate::Expression::RayQueryGetIntersection { query, committed } => { - let width = 4; - let query_id = self.cached[query]; - let intersection_id = self.writer.get_constant_scalar( - crate::ScalarValue::Uint( - spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, - ), - width, - ); if !committed { return Err(Error::FeatureNotImplemented("candidate intersection")); } - - let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Uint, - width, - pointer_space: None, - })); - let kind_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionTypeKHR, - flag_type_id, - kind_id, - query_id, - intersection_id, - )); - - let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Float, - width, - pointer_space: None, - })); - let t_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionTKHR, - scalar_type_id, - t_id, - query_id, - intersection_id, - )); - - let id = self.gen_id(); - let intersection_type_id = self.get_type_id(LookupType::Handle( - self.ir_module.special_types.ray_intersection.unwrap(), - )); - //Note: the arguments must match `generate_ray_intersection_type` layout - block.body.push(Instruction::composite_construct( - intersection_type_id, - id, - &[kind_id, t_id], - )); - id + self.write_ray_query_get_intersection(query, block) } }; @@ -2252,108 +2203,7 @@ impl<'w> BlockContext<'w> { block.body.push(instruction); } crate::Statement::RayQuery { query, ref fun } => { - let query_id = self.cached[query]; - match *fun { - crate::RayQueryFunction::Initialize { - acceleration_structure, - descriptor, - } => { - //Note: composite extract indices and types must match `generate_ray_desc_type` - let desc_id = self.cached[descriptor]; - let acc_struct_id = self.get_image_id(acceleration_structure); - let width = 4; - - let flag_type_id = - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Uint, - width, - pointer_space: None, - })); - let ray_flags_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - flag_type_id, - ray_flags_id, - desc_id, - &[0], - )); - let cull_mask_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - flag_type_id, - cull_mask_id, - desc_id, - &[1], - )); - - let scalar_type_id = - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Float, - width, - pointer_space: None, - })); - let tmin_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - scalar_type_id, - tmin_id, - desc_id, - &[2], - )); - let tmax_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - scalar_type_id, - tmax_id, - desc_id, - &[3], - )); - - let vector_type_id = - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Tri), - kind: crate::ScalarKind::Float, - width, - pointer_space: None, - })); - let ray_origin_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - vector_type_id, - ray_origin_id, - desc_id, - &[4], - )); - let ray_dir_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - vector_type_id, - ray_dir_id, - desc_id, - &[5], - )); - - block.body.push(Instruction::ray_query_initialize( - query_id, - acc_struct_id, - ray_flags_id, - cull_mask_id, - ray_origin_id, - tmin_id, - ray_dir_id, - tmax_id, - )); - } - crate::RayQueryFunction::Proceed { result } => { - let id = self.gen_id(); - self.cached[result] = id; - let result_type_id = - self.get_expression_type_id(&self.fun_info[result].ty); - - block.body.push(Instruction::ray_query_proceed( - result_type_id, - id, - query_id, - )); - } - crate::RayQueryFunction::Terminate => {} - } + self.write_ray_query_function(query, fun, &mut block); } } } diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 5fba4f0dea..9b084911b1 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -10,6 +10,7 @@ mod image; mod index; mod instructions; mod layout; +mod ray; mod recyclable; mod selection; mod writer; diff --git a/src/back/spv/ray.rs b/src/back/spv/ray.rs new file mode 100644 index 0000000000..79eb2ff971 --- /dev/null +++ b/src/back/spv/ray.rs @@ -0,0 +1,273 @@ +/*! +Generating SPIR-V for ray query operations. +*/ + +use super::{Block, BlockContext, Instruction, LocalType, LookupType}; +use crate::arena::Handle; + +impl<'w> BlockContext<'w> { + pub(super) fn write_ray_query_function( + &mut self, + query: Handle, + function: &crate::RayQueryFunction, + block: &mut Block, + ) { + let query_id = self.cached[query]; + match *function { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //Note: composite extract indices and types must match `generate_ray_desc_type` + let desc_id = self.cached[descriptor]; + let acc_struct_id = self.get_image_id(acceleration_structure); + let width = 4; + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let ray_flags_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + ray_flags_id, + desc_id, + &[0], + )); + let cull_mask_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + cull_mask_id, + desc_id, + &[1], + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let tmin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmin_id, + desc_id, + &[2], + )); + let tmax_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmax_id, + desc_id, + &[3], + )); + + let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let ray_origin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_origin_id, + desc_id, + &[4], + )); + let ray_dir_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_dir_id, + desc_id, + &[5], + )); + + block.body.push(Instruction::ray_query_initialize( + query_id, + acc_struct_id, + ray_flags_id, + cull_mask_id, + ray_origin_id, + tmin_id, + ray_dir_id, + tmax_id, + )); + } + crate::RayQueryFunction::Proceed { result } => { + let id = self.gen_id(); + self.cached[result] = id; + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + + block + .body + .push(Instruction::ray_query_proceed(result_type_id, id, query_id)); + } + crate::RayQueryFunction::Terminate => {} + } + } + + pub(super) fn write_ray_query_get_intersection( + &mut self, + query: Handle, + block: &mut Block, + ) -> spirv::Word { + let width = 4; + let query_id = self.cached[query]; + let intersection_id = self.writer.get_constant_scalar( + crate::ScalarValue::Uint( + spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, + ), + width, + ); + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let kind_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTypeKHR, + flag_type_id, + kind_id, + query_id, + intersection_id, + )); + let instance_custom_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR, + flag_type_id, + instance_custom_index_id, + query_id, + intersection_id, + )); + let instance_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceIdKHR, + flag_type_id, + instance_id, + query_id, + intersection_id, + )); + let sbt_record_offset_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR, + flag_type_id, + sbt_record_offset_id, + query_id, + intersection_id, + )); + let geometry_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionGeometryIndexKHR, + flag_type_id, + geometry_index_id, + query_id, + intersection_id, + )); + let primitive_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR, + flag_type_id, + primitive_index_id, + query_id, + intersection_id, + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let t_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); + + let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Bi), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let barycentrics_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionBarycentricsKHR, + barycentrics_type_id, + barycentrics_id, + query_id, + intersection_id, + )); + + let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + pointer_space: None, + })); + let front_face_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionFrontFaceKHR, + bool_type_id, + front_face_id, + query_id, + intersection_id, + )); + + let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width, + })); + let object_to_world_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionObjectToWorldKHR, + transform_type_id, + object_to_world_id, + query_id, + intersection_id, + )); + let world_to_object_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionWorldToObjectKHR, + transform_type_id, + world_to_object_id, + query_id, + intersection_id, + )); + + let id = self.gen_id(); + let intersection_type_id = self.get_type_id(LookupType::Handle( + self.ir_module.special_types.ray_intersection.unwrap(), + )); + //Note: the arguments must match `generate_ray_intersection_type` layout + block.body.push(Instruction::composite_construct( + intersection_type_id, + id, + &[ + kind_id, + t_id, + instance_custom_index_id, + instance_id, + sbt_record_offset_id, + geometry_index_id, + primitive_index_id, + barycentrics_id, + front_face_id, + object_to_world_id, + world_to_object_id, + ], + )); + id + } +} diff --git a/src/front/type_gen.rs b/src/front/type_gen.rs index 18d9ddd54c..bb734ac69c 100644 --- a/src/front/type_gen.rs +++ b/src/front/type_gen.rs @@ -5,6 +5,7 @@ Type generators. use crate::{arena::Handle, span::Span}; impl crate::Module { + //Note: has to match `struct RayDesc` pub(super) fn generate_ray_desc_type(&mut self) -> Handle { if let Some(handle) = self.special_types.ray_desc { return handle; @@ -95,6 +96,7 @@ impl crate::Module { handle } + //Note: has to match `struct RayIntersection` pub(super) fn generate_ray_intersection_type(&mut self) -> Handle { if let Some(handle) = self.special_types.ray_intersection { return handle; @@ -121,6 +123,38 @@ impl crate::Module { }, Span::UNDEFINED, ); + let ty_barycentrics = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + width, + size: crate::VectorSize::Bi, + kind: crate::ScalarKind::Float, + }, + }, + Span::UNDEFINED, + ); + let ty_bool = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width: crate::BOOL_WIDTH, + kind: crate::ScalarKind::Bool, + }, + }, + Span::UNDEFINED, + ); + let ty_transform = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width, + }, + }, + Span::UNDEFINED, + ); let handle = self.types.insert( crate::Type { @@ -139,9 +173,62 @@ impl crate::Module { binding: None, offset: 4, }, - //TODO: the rest + crate::StructMember { + name: Some("instance_custom_index".to_string()), + ty: ty_flag, + binding: None, + offset: 8, + }, + crate::StructMember { + name: Some("instance_id".to_string()), + ty: ty_flag, + binding: None, + offset: 12, + }, + crate::StructMember { + name: Some("sbt_record_offset".to_string()), + ty: ty_flag, + binding: None, + offset: 16, + }, + crate::StructMember { + name: Some("geometry_index".to_string()), + ty: ty_flag, + binding: None, + offset: 20, + }, + crate::StructMember { + name: Some("primitive_index".to_string()), + ty: ty_flag, + binding: None, + offset: 24, + }, + crate::StructMember { + name: Some("barycentrics".to_string()), + ty: ty_barycentrics, + binding: None, + offset: 28, + }, + crate::StructMember { + name: Some("front_face".to_string()), + ty: ty_bool, + binding: None, + offset: 36, + }, + crate::StructMember { + name: Some("object_to_world".to_string()), + ty: ty_transform, + binding: None, + offset: 48, + }, + crate::StructMember { + name: Some("world_to_object".to_string()), + ty: ty_transform, + binding: None, + offset: 112, + }, ], - span: 8, + span: 176, }, }, Span::UNDEFINED, diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 7b1053bdcb..5eabf3a2d3 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -29,7 +29,8 @@ struct RayIntersection { primitive_index: u32, barycentrics: vec2, front_face: bool, - //TODO: object ray direction, origin, matrices + object_to_world: mat4x3, + world_to_object: mat4x3, } */ diff --git a/tests/out/spv/ray-query.spvasm b/tests/out/spv/ray-query.spvasm index 455a8bffe9..6bc41ee30f 100644 --- a/tests/out/spv/ray-query.spvasm +++ b/tests/out/spv/ray-query.spvasm @@ -1,14 +1,14 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 52 +; Bound: 63 OpCapability RayQueryKHR OpCapability Shader OpExtension "SPV_KHR_ray_query" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %26 "main" %18 %20 -OpExecutionMode %26 LocalSize 1 1 1 +OpEntryPoint GLCompute %29 "main" %21 %23 +OpExecutionMode %29 LocalSize 1 1 1 OpMemberDecorate %13 0 Offset 0 OpMemberDecorate %16 0 Offset 0 OpMemberDecorate %16 1 Offset 4 @@ -16,14 +16,27 @@ OpMemberDecorate %16 2 Offset 8 OpMemberDecorate %16 3 Offset 12 OpMemberDecorate %16 4 Offset 16 OpMemberDecorate %16 5 Offset 32 -OpMemberDecorate %17 0 Offset 0 -OpMemberDecorate %17 1 Offset 4 -OpDecorate %18 DescriptorSet 0 -OpDecorate %18 Binding 0 -OpDecorate %20 DescriptorSet 0 -OpDecorate %20 Binding 1 -OpDecorate %21 Block -OpMemberDecorate %21 0 Offset 0 +OpMemberDecorate %20 0 Offset 0 +OpMemberDecorate %20 1 Offset 4 +OpMemberDecorate %20 2 Offset 8 +OpMemberDecorate %20 3 Offset 12 +OpMemberDecorate %20 4 Offset 16 +OpMemberDecorate %20 5 Offset 20 +OpMemberDecorate %20 6 Offset 24 +OpMemberDecorate %20 7 Offset 28 +OpMemberDecorate %20 8 Offset 36 +OpMemberDecorate %20 9 Offset 48 +OpMemberDecorate %20 9 ColMajor +OpMemberDecorate %20 9 MatrixStride 16 +OpMemberDecorate %20 10 Offset 112 +OpMemberDecorate %20 10 ColMajor +OpMemberDecorate %20 10 MatrixStride 16 +OpDecorate %21 DescriptorSet 0 +OpDecorate %21 Binding 0 +OpDecorate %23 DescriptorSet 0 +OpDecorate %23 Binding 1 +OpDecorate %24 Block +OpMemberDecorate %24 0 Offset 0 %2 = OpTypeVoid %4 = OpTypeInt 32 0 %3 = OpConstant %4 4 @@ -39,43 +52,54 @@ OpMemberDecorate %21 0 Offset 0 %14 = OpTypeRayQueryKHR %15 = OpTypeVector %7 3 %16 = OpTypeStruct %4 %4 %7 %7 %15 %15 -%17 = OpTypeStruct %4 %7 -%19 = OpTypePointer UniformConstant %12 -%18 = OpVariable %19 UniformConstant -%21 = OpTypeStruct %13 -%22 = OpTypePointer StorageBuffer %21 -%20 = OpVariable %22 StorageBuffer -%24 = OpTypePointer Function %14 -%27 = OpTypeFunction %2 -%29 = OpTypePointer StorageBuffer %13 -%42 = OpTypeBool -%43 = OpConstant %4 1 -%47 = OpTypePointer StorageBuffer %4 -%26 = OpFunction %2 None %27 -%25 = OpLabel -%23 = OpVariable %24 Function -%28 = OpLoad %12 %18 -%30 = OpAccessChain %29 %20 %11 -OpBranch %31 -%31 = OpLabel -%32 = OpCompositeConstruct %15 %9 %9 %9 -%33 = OpCompositeConstruct %15 %9 %10 %9 -%34 = OpCompositeConstruct %16 %3 %5 %6 %8 %32 %33 -%35 = OpCompositeExtract %4 %34 0 -%36 = OpCompositeExtract %4 %34 1 -%37 = OpCompositeExtract %7 %34 2 -%38 = OpCompositeExtract %7 %34 3 -%39 = OpCompositeExtract %15 %34 4 -%40 = OpCompositeExtract %15 %34 5 -OpRayQueryInitializeKHR %23 %28 %35 %36 %39 %37 %40 %38 -%41 = OpRayQueryProceedKHR %42 %23 -%44 = OpRayQueryGetIntersectionTypeKHR %4 %23 %43 -%45 = OpRayQueryGetIntersectionTKHR %7 %23 %43 -%46 = OpCompositeConstruct %17 %44 %45 -%48 = OpCompositeExtract %4 %46 0 -%49 = OpIEqual %42 %48 %11 -%50 = OpSelect %4 %49 %43 %11 -%51 = OpAccessChain %47 %30 %11 -OpStore %51 %50 +%17 = OpTypeVector %7 2 +%18 = OpTypeBool +%19 = OpTypeMatrix %15 4 +%20 = OpTypeStruct %4 %7 %4 %4 %4 %4 %4 %17 %18 %19 %19 +%22 = OpTypePointer UniformConstant %12 +%21 = OpVariable %22 UniformConstant +%24 = OpTypeStruct %13 +%25 = OpTypePointer StorageBuffer %24 +%23 = OpVariable %25 StorageBuffer +%27 = OpTypePointer Function %14 +%30 = OpTypeFunction %2 +%32 = OpTypePointer StorageBuffer %13 +%45 = OpConstant %4 1 +%58 = OpTypePointer StorageBuffer %4 +%29 = OpFunction %2 None %30 +%28 = OpLabel +%26 = OpVariable %27 Function +%31 = OpLoad %12 %21 +%33 = OpAccessChain %32 %23 %11 +OpBranch %34 +%34 = OpLabel +%35 = OpCompositeConstruct %15 %9 %9 %9 +%36 = OpCompositeConstruct %15 %9 %10 %9 +%37 = OpCompositeConstruct %16 %3 %5 %6 %8 %35 %36 +%38 = OpCompositeExtract %4 %37 0 +%39 = OpCompositeExtract %4 %37 1 +%40 = OpCompositeExtract %7 %37 2 +%41 = OpCompositeExtract %7 %37 3 +%42 = OpCompositeExtract %15 %37 4 +%43 = OpCompositeExtract %15 %37 5 +OpRayQueryInitializeKHR %26 %31 %38 %39 %42 %40 %43 %41 +%44 = OpRayQueryProceedKHR %18 %26 +%46 = OpRayQueryGetIntersectionTypeKHR %4 %26 %45 +%47 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %4 %26 %45 +%48 = OpRayQueryGetIntersectionInstanceIdKHR %4 %26 %45 +%49 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %4 %26 %45 +%50 = OpRayQueryGetIntersectionGeometryIndexKHR %4 %26 %45 +%51 = OpRayQueryGetIntersectionPrimitiveIndexKHR %4 %26 %45 +%52 = OpRayQueryGetIntersectionTKHR %7 %26 %45 +%53 = OpRayQueryGetIntersectionBarycentricsKHR %17 %26 %45 +%54 = OpRayQueryGetIntersectionFrontFaceKHR %18 %26 %45 +%55 = OpRayQueryGetIntersectionObjectToWorldKHR %19 %26 %45 +%56 = OpRayQueryGetIntersectionWorldToObjectKHR %19 %26 %45 +%57 = OpCompositeConstruct %20 %46 %52 %47 %48 %49 %50 %51 %53 %54 %55 %56 +%59 = OpCompositeExtract %4 %57 0 +%60 = OpIEqual %18 %59 %11 +%61 = OpSelect %4 %60 %45 %11 +%62 = OpAccessChain %58 %33 %11 +OpStore %62 %61 OpReturn OpFunctionEnd \ No newline at end of file