[naga spv-out] Introduce get_handle_type_id helper function.

Introduce a new helper function on `back::spv::Writer` and
`BlockContext` that looks up the SPIR-V type id corresponding to a
Naga IR `Handle<Type>`. Use it where appropriate.
This commit is contained in:
Jim Blandy
2025-02-27 12:00:52 -08:00
parent 77265d38cd
commit e84c9da5b7
5 changed files with 30 additions and 22 deletions

View File

@@ -575,7 +575,7 @@ impl BlockContext<'_> {
}
};
let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
let binding_type_id = self.get_handle_type_id(binding_type);
let load_id = self.gen_id();
block.body.push(Instruction::load(
@@ -666,7 +666,7 @@ impl BlockContext<'_> {
}
};
let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
let binding_type_id = self.get_handle_type_id(binding_type);
let load_id = self.gen_id();
block.body.push(Instruction::load(
@@ -1920,7 +1920,7 @@ impl BlockContext<'_> {
.writer
.write_ray_query_get_intersection_function(committed, self.ir_module);
let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
let intersection_type_id = self.get_type_id(LookupType::Handle(ray_intersection));
let intersection_type_id = self.get_handle_type_id(ray_intersection);
let id = self.gen_id();
block.body.push(Instruction::function_call(
intersection_type_id,
@@ -3250,7 +3250,7 @@ impl BlockContext<'_> {
// need to end it with some kind of return instruction.
BlockExit::Return => match self.ir_function.result {
Some(ref result) if self.function.entry_point_context.is_none() => {
let type_id = self.get_type_id(LookupType::Handle(result.ty));
let type_id = self.get_handle_type_id(result.ty);
let null_id = self.writer.get_constant_null(type_id);
Instruction::return_value(null_id)
}

View File

@@ -845,7 +845,7 @@ impl BlockContext<'_> {
};
// OpTypeSampledImage
let image_type_id = self.get_type_id(LookupType::Handle(image_type));
let image_type_id = self.get_handle_type_id(image_type);
let sampled_image_type_id =
self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id }));

View File

@@ -748,6 +748,10 @@ impl BlockContext<'_> {
self.writer.get_type_id(lookup_type)
}
fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word {
self.writer.get_handle_type_id(handle)
}
fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
self.writer.get_expression_type_id(tr)
}

View File

@@ -21,7 +21,7 @@ impl Writer {
return func_id;
}
let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
let intersection_type_id = self.get_type_id(LookupType::Handle(ray_intersection));
let intersection_type_id = self.get_handle_type_id(ray_intersection);
let intersection_pointer_type_id =
self.get_type_id(LookupType::Local(LocalType::Pointer {
base: ray_intersection,

View File

@@ -222,6 +222,10 @@ impl Writer {
}
}
pub(super) fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word {
self.get_type_id(LookupType::Handle(handle))
}
pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
match *tr {
TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
@@ -639,7 +643,7 @@ impl Writer {
let handle_ty = ir_module.types[argument.ty].inner.is_handle();
let argument_type_id = match handle_ty {
true => self.get_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant),
false => self.get_type_id(LookupType::Handle(argument.ty)),
false => self.get_handle_type_id(argument.ty),
};
if let Some(ref mut iface) = interface {
@@ -671,7 +675,7 @@ impl Writer {
let struct_id = self.id_gen.next();
let mut constituent_ids = Vec::with_capacity(members.len());
for member in members {
let type_id = self.get_type_id(LookupType::Handle(member.ty));
let type_id = self.get_handle_type_id(member.ty);
let name = member.name.as_deref();
let binding = member.binding.as_ref().unwrap();
let varying_id = self.write_varying(
@@ -716,7 +720,7 @@ impl Writer {
handle_id: if handle_ty {
let id = self.id_gen.next();
prelude.body.push(Instruction::load(
self.get_type_id(LookupType::Handle(argument.ty)),
self.get_handle_type_id(argument.ty),
id,
argument_id,
None,
@@ -738,7 +742,7 @@ impl Writer {
if let Some(ref binding) = result.binding {
has_point_size |=
*binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
let type_id = self.get_type_id(LookupType::Handle(result.ty));
let type_id = self.get_handle_type_id(result.ty);
let varying_id = self.write_varying(
ir_module,
iface.stage,
@@ -757,7 +761,7 @@ impl Writer {
ir_module.types[result.ty].inner
{
for member in members {
let type_id = self.get_type_id(LookupType::Handle(member.ty));
let type_id = self.get_handle_type_id(member.ty);
let name = member.name.as_deref();
let binding = member.binding.as_ref().unwrap();
has_point_size |=
@@ -804,7 +808,7 @@ impl Writer {
}
self.void_type
} else {
self.get_type_id(LookupType::Handle(result.ty))
self.get_handle_type_id(result.ty)
}
}
None => self.void_type,
@@ -860,7 +864,7 @@ impl Writer {
}
_ => {
if var.space == crate::AddressSpace::Handle {
let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
let var_type_id = self.get_handle_type_id(var.ty);
let id = self.id_gen.next();
prelude
.body
@@ -939,7 +943,7 @@ impl Writer {
init_word.or_else(|| match ir_module.types[variable.ty].inner {
crate::TypeInner::RayQuery { .. } => None,
_ => {
let type_id = context.get_type_id(LookupType::Handle(variable.ty));
let type_id = context.get_handle_type_id(variable.ty);
Some(context.writer.write_constant_null(type_id))
}
}),
@@ -1222,7 +1226,7 @@ impl Writer {
Instruction::type_pointer(id, class, base_id)
}
LocalType::Pointer { base, class } => {
let type_id = self.get_type_id(LookupType::Handle(base));
let type_id = self.get_handle_type_id(base);
Instruction::type_pointer(id, class, type_id)
}
LocalType::Image(image) => {
@@ -1235,7 +1239,7 @@ impl Writer {
Instruction::type_sampled_image(id, image_type_id)
}
LocalType::BindingArray { base, size } => {
let inner_ty = self.get_type_id(LookupType::Handle(base));
let inner_ty = self.get_handle_type_id(base);
let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
Instruction::type_array(id, inner_ty, scalar_id)
}
@@ -1289,7 +1293,7 @@ impl Writer {
crate::TypeInner::Array { base, size, stride } => {
self.decorate(id, Decoration::ArrayStride, &[stride]);
let type_id = self.get_type_id(LookupType::Handle(base));
let type_id = self.get_handle_type_id(base);
match size {
crate::ArraySize::Constant(length) => {
let length_id = self.get_index_constant(length.get());
@@ -1300,7 +1304,7 @@ impl Writer {
}
}
crate::TypeInner::BindingArray { base, size } => {
let type_id = self.get_type_id(LookupType::Handle(base));
let type_id = self.get_handle_type_id(base);
match size {
crate::ArraySize::Constant(length) => {
let length_id = self.get_index_constant(length.get());
@@ -1329,7 +1333,7 @@ impl Writer {
_ => (),
}
self.decorate_struct_member(id, index, member, arena)?;
let member_id = self.get_type_id(LookupType::Handle(member.ty));
let member_id = self.get_handle_type_id(member.ty);
member_ids.push(member_id);
}
if has_runtime_array {
@@ -1557,7 +1561,7 @@ impl Writer {
self.constant_ids[constant.init]
}
crate::Expression::ZeroValue(ty) => {
let type_id = self.get_type_id(LookupType::Handle(ty));
let type_id = self.get_handle_type_id(ty);
self.get_constant_null(type_id)
}
crate::Expression::Compose { ty, ref components } => {
@@ -1640,7 +1644,7 @@ impl Writer {
// variables in the `Uniform` and `StorageBuffer` address spaces
// get wrapped, and we're initializing `WorkGroup` variables.
let var_id = self.global_variables[handle].var_id;
let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
let var_type_id = self.get_handle_type_id(var.ty);
let init_word = self.get_constant_null(var_type_id);
Instruction::store(var_id, init_word, None)
})
@@ -2079,7 +2083,7 @@ impl Writer {
}
}
if should_decorate {
let decorated_id = self.get_type_id(LookupType::Handle(base));
let decorated_id = self.get_handle_type_id(base);
self.decorate(decorated_id, Decoration::Block, &[]);
}
}