From e84c9da5b7352b1dde90f2c54fea4b8db72bb022 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 27 Feb 2025 12:00:52 -0800 Subject: [PATCH] [naga spv-out] Introduce `get_handle_type_id` helper function. Introduce a new helper function on `back::spv::Writer` and `BlockContext` that looks up the SPIR-V type id corresponding to a Naga IR `Handle`. Use it where appropriate. --- naga/src/back/spv/block.rs | 8 ++++---- naga/src/back/spv/image.rs | 2 +- naga/src/back/spv/mod.rs | 4 ++++ naga/src/back/spv/ray.rs | 2 +- naga/src/back/spv/writer.rs | 36 ++++++++++++++++++++---------------- 5 files changed, 30 insertions(+), 22 deletions(-) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index d66c579cca..8c6180a263 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -575,7 +575,7 @@ impl BlockContext<'_> { } }; - let binding_type_id = self.get_type_id(LookupType::Handle(binding_type)); + let binding_type_id = self.get_handle_type_id(binding_type); let load_id = self.gen_id(); block.body.push(Instruction::load( @@ -666,7 +666,7 @@ impl BlockContext<'_> { } }; - let binding_type_id = self.get_type_id(LookupType::Handle(binding_type)); + let binding_type_id = self.get_handle_type_id(binding_type); let load_id = self.gen_id(); block.body.push(Instruction::load( @@ -1920,7 +1920,7 @@ impl BlockContext<'_> { .writer .write_ray_query_get_intersection_function(committed, self.ir_module); let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap(); - let intersection_type_id = self.get_type_id(LookupType::Handle(ray_intersection)); + let intersection_type_id = self.get_handle_type_id(ray_intersection); let id = self.gen_id(); block.body.push(Instruction::function_call( intersection_type_id, @@ -3250,7 +3250,7 @@ impl BlockContext<'_> { // need to end it with some kind of return instruction. BlockExit::Return => match self.ir_function.result { Some(ref result) if self.function.entry_point_context.is_none() => { - let type_id = self.get_type_id(LookupType::Handle(result.ty)); + let type_id = self.get_handle_type_id(result.ty); let null_id = self.writer.get_constant_null(type_id); Instruction::return_value(null_id) } diff --git a/naga/src/back/spv/image.rs b/naga/src/back/spv/image.rs index 816dd727a6..e705080850 100644 --- a/naga/src/back/spv/image.rs +++ b/naga/src/back/spv/image.rs @@ -845,7 +845,7 @@ impl BlockContext<'_> { }; // OpTypeSampledImage - let image_type_id = self.get_type_id(LookupType::Handle(image_type)); + let image_type_id = self.get_handle_type_id(image_type); let sampled_image_type_id = self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id })); diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 8faf9e38f6..4aac3e1078 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -748,6 +748,10 @@ impl BlockContext<'_> { self.writer.get_type_id(lookup_type) } + fn get_handle_type_id(&mut self, handle: Handle) -> Word { + self.writer.get_handle_type_id(handle) + } + fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word { self.writer.get_expression_type_id(tr) } diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs index 41f50bf61b..65c95a604b 100644 --- a/naga/src/back/spv/ray.rs +++ b/naga/src/back/spv/ray.rs @@ -21,7 +21,7 @@ impl Writer { return func_id; } let ray_intersection = ir_module.special_types.ray_intersection.unwrap(); - let intersection_type_id = self.get_type_id(LookupType::Handle(ray_intersection)); + let intersection_type_id = self.get_handle_type_id(ray_intersection); let intersection_pointer_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer { base: ray_intersection, diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 45bf6b45dc..77247bc6f1 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -222,6 +222,10 @@ impl Writer { } } + pub(super) fn get_handle_type_id(&mut self, handle: Handle) -> Word { + self.get_type_id(LookupType::Handle(handle)) + } + pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType { match *tr { TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle), @@ -639,7 +643,7 @@ impl Writer { let handle_ty = ir_module.types[argument.ty].inner.is_handle(); let argument_type_id = match handle_ty { true => self.get_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant), - false => self.get_type_id(LookupType::Handle(argument.ty)), + false => self.get_handle_type_id(argument.ty), }; if let Some(ref mut iface) = interface { @@ -671,7 +675,7 @@ impl Writer { let struct_id = self.id_gen.next(); let mut constituent_ids = Vec::with_capacity(members.len()); for member in members { - let type_id = self.get_type_id(LookupType::Handle(member.ty)); + let type_id = self.get_handle_type_id(member.ty); let name = member.name.as_deref(); let binding = member.binding.as_ref().unwrap(); let varying_id = self.write_varying( @@ -716,7 +720,7 @@ impl Writer { handle_id: if handle_ty { let id = self.id_gen.next(); prelude.body.push(Instruction::load( - self.get_type_id(LookupType::Handle(argument.ty)), + self.get_handle_type_id(argument.ty), id, argument_id, None, @@ -738,7 +742,7 @@ impl Writer { if let Some(ref binding) = result.binding { has_point_size |= *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize); - let type_id = self.get_type_id(LookupType::Handle(result.ty)); + let type_id = self.get_handle_type_id(result.ty); let varying_id = self.write_varying( ir_module, iface.stage, @@ -757,7 +761,7 @@ impl Writer { ir_module.types[result.ty].inner { for member in members { - let type_id = self.get_type_id(LookupType::Handle(member.ty)); + let type_id = self.get_handle_type_id(member.ty); let name = member.name.as_deref(); let binding = member.binding.as_ref().unwrap(); has_point_size |= @@ -804,7 +808,7 @@ impl Writer { } self.void_type } else { - self.get_type_id(LookupType::Handle(result.ty)) + self.get_handle_type_id(result.ty) } } None => self.void_type, @@ -860,7 +864,7 @@ impl Writer { } _ => { if var.space == crate::AddressSpace::Handle { - let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); + let var_type_id = self.get_handle_type_id(var.ty); let id = self.id_gen.next(); prelude .body @@ -939,7 +943,7 @@ impl Writer { init_word.or_else(|| match ir_module.types[variable.ty].inner { crate::TypeInner::RayQuery { .. } => None, _ => { - let type_id = context.get_type_id(LookupType::Handle(variable.ty)); + let type_id = context.get_handle_type_id(variable.ty); Some(context.writer.write_constant_null(type_id)) } }), @@ -1222,7 +1226,7 @@ impl Writer { Instruction::type_pointer(id, class, base_id) } LocalType::Pointer { base, class } => { - let type_id = self.get_type_id(LookupType::Handle(base)); + let type_id = self.get_handle_type_id(base); Instruction::type_pointer(id, class, type_id) } LocalType::Image(image) => { @@ -1235,7 +1239,7 @@ impl Writer { Instruction::type_sampled_image(id, image_type_id) } LocalType::BindingArray { base, size } => { - let inner_ty = self.get_type_id(LookupType::Handle(base)); + let inner_ty = self.get_handle_type_id(base); let scalar_id = self.get_constant_scalar(crate::Literal::U32(size)); Instruction::type_array(id, inner_ty, scalar_id) } @@ -1289,7 +1293,7 @@ impl Writer { crate::TypeInner::Array { base, size, stride } => { self.decorate(id, Decoration::ArrayStride, &[stride]); - let type_id = self.get_type_id(LookupType::Handle(base)); + let type_id = self.get_handle_type_id(base); match size { crate::ArraySize::Constant(length) => { let length_id = self.get_index_constant(length.get()); @@ -1300,7 +1304,7 @@ impl Writer { } } crate::TypeInner::BindingArray { base, size } => { - let type_id = self.get_type_id(LookupType::Handle(base)); + let type_id = self.get_handle_type_id(base); match size { crate::ArraySize::Constant(length) => { let length_id = self.get_index_constant(length.get()); @@ -1329,7 +1333,7 @@ impl Writer { _ => (), } self.decorate_struct_member(id, index, member, arena)?; - let member_id = self.get_type_id(LookupType::Handle(member.ty)); + let member_id = self.get_handle_type_id(member.ty); member_ids.push(member_id); } if has_runtime_array { @@ -1557,7 +1561,7 @@ impl Writer { self.constant_ids[constant.init] } crate::Expression::ZeroValue(ty) => { - let type_id = self.get_type_id(LookupType::Handle(ty)); + let type_id = self.get_handle_type_id(ty); self.get_constant_null(type_id) } crate::Expression::Compose { ty, ref components } => { @@ -1640,7 +1644,7 @@ impl Writer { // variables in the `Uniform` and `StorageBuffer` address spaces // get wrapped, and we're initializing `WorkGroup` variables. let var_id = self.global_variables[handle].var_id; - let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); + let var_type_id = self.get_handle_type_id(var.ty); let init_word = self.get_constant_null(var_type_id); Instruction::store(var_id, init_word, None) }) @@ -2079,7 +2083,7 @@ impl Writer { } } if should_decorate { - let decorated_id = self.get_type_id(LookupType::Handle(base)); + let decorated_id = self.get_handle_type_id(base); self.decorate(decorated_id, Decoration::Block, &[]); } }