From 19429a1dc9bf5e8ca8f4ad66b20618f1e4e9dbe1 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 28 Mar 2025 12:43:09 -0700 Subject: [PATCH] [naga] Refactor `BlockContext` type resolution methods Change `resolve_type` and `resolve_type_impl` to return `TypeResolution`s. Add a new method `resolve_type_inner` that returns a `TypeInner` (i.e. what `resolve_type` used to do). --- naga/src/valid/function.rs | 70 ++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 29 deletions(-) diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 84116c6ced..7a25175970 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -280,11 +280,11 @@ impl<'a> BlockContext<'a> { &self, handle: Handle, valid_expressions: &HandleSet, - ) -> Result<&crate::TypeInner, WithSpan> { + ) -> Result<&TypeResolution, WithSpan> { if !valid_expressions.contains(handle) { Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions)) } else { - Ok(self.info[handle].ty.inner_with(self.types)) + Ok(&self.info[handle].ty) } } @@ -292,11 +292,20 @@ impl<'a> BlockContext<'a> { &self, handle: Handle, valid_expressions: &HandleSet, - ) -> Result<&crate::TypeInner, WithSpan> { + ) -> Result<&TypeResolution, WithSpan> { self.resolve_type_impl(handle, valid_expressions) .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span()) } + fn resolve_type_inner( + &self, + handle: Handle, + valid_expressions: &HandleSet, + ) -> Result<&crate::TypeInner, WithSpan> { + self.resolve_type(handle, valid_expressions) + .map(|tr| tr.inner_with(self.types)) + } + fn resolve_pointer_type(&self, handle: Handle) -> &crate::TypeInner { self.info[handle].ty.inner_with(self.types) } @@ -330,7 +339,7 @@ impl super::Validator { .with_span_handle(expr, context.expressions) })?; let arg_inner = &context.types[arg.ty].inner; - if !ty.non_struct_equivalent(arg_inner, context.types) { + if !ty.inner_with(context.types).non_struct_equivalent(arg_inner, context.types) { return Err(CallError::ArgumentType { index, required: arg.ty, @@ -393,7 +402,7 @@ impl super::Validator { context: &BlockContext, ) -> Result<(), WithSpan> { // The `pointer` operand must be a pointer to an atomic value. - let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?; + let pointer_inner = context.resolve_type_inner(pointer, &self.valid_expression_set)?; let crate::TypeInner::Pointer { base: pointer_base, space: pointer_space, @@ -415,7 +424,7 @@ impl super::Validator { }; // The `value` operand must be a scalar of the same type as the atomic. - let value_inner = context.resolve_type(value, &self.valid_expression_set)?; + let value_inner = context.resolve_type_inner(value, &self.valid_expression_set)?; let crate::TypeInner::Scalar(value_scalar) = *value_inner else { log::error!("Atomic operand type {:?}", *value_inner); return Err(AtomicError::InvalidOperand(value) @@ -543,7 +552,7 @@ impl super::Validator { // The comparison value must be a scalar of the same type as the // atomic we're operating on. let compare_inner = - context.resolve_type(compare, &self.valid_expression_set)?; + context.resolve_type_inner(compare, &self.valid_expression_set)?; if !compare_inner.non_struct_equivalent(value_inner, context.types) { log::error!( "Atomic exchange comparison has a different type from the value" @@ -620,7 +629,7 @@ impl super::Validator { result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?; let (is_scalar, scalar) = match *argument_inner { crate::TypeInner::Scalar(scalar) => (true, scalar), @@ -695,7 +704,7 @@ impl super::Validator { | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) | crate::GatherMode::ShuffleXor(index) => { - let index_ty = context.resolve_type(index, &self.valid_expression_set)?; + let index_ty = context.resolve_type_inner(index, &self.valid_expression_set)?; match *index_ty { crate::TypeInner::Scalar(crate::Scalar::U32) => {} _ => { @@ -710,7 +719,7 @@ impl super::Validator { } } } - let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?; if !matches!(*argument_inner, crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. } if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) @@ -802,7 +811,7 @@ impl super::Validator { ref accept, ref reject, } => { - match *context.resolve_type(condition, &self.valid_expression_set)? { + match *context.resolve_type_inner(condition, &self.valid_expression_set)? { Ti::Scalar(crate::Scalar { kind: crate::ScalarKind::Bool, width: _, @@ -820,7 +829,7 @@ impl super::Validator { ref cases, } => { let uint = match context - .resolve_type(selector, &self.valid_expression_set)? + .resolve_type_inner(selector, &self.valid_expression_set)? .scalar_kind() { Some(crate::ScalarKind::Uint) => true, @@ -917,7 +926,7 @@ impl super::Validator { .stages; if let Some(condition) = break_if { - match *context.resolve_type(condition, &self.valid_expression_set)? { + match *context.resolve_type_inner(condition, &self.valid_expression_set)? { Ti::Scalar(crate::Scalar { kind: crate::ScalarKind::Bool, width: _, @@ -961,7 +970,7 @@ impl super::Validator { let okay = match (value_ty, expected_ty) { (None, None) => true, (Some(value_inner), Some(expected_inner)) => { - value_inner.non_struct_equivalent(expected_inner, context.types) + value_inner.inner_with(context.types).non_struct_equivalent(expected_inner, context.types) } (_, _) => false, }; @@ -1027,7 +1036,7 @@ impl super::Validator { } } - let value_ty = context.resolve_type(value, &self.valid_expression_set)?; + let value_ty = context.resolve_type_inner(value, &self.valid_expression_set)?; match *value_ty { Ti::Image { .. } | Ti::Sampler { .. } => { return Err(FunctionError::InvalidStoreTexture { @@ -1145,7 +1154,7 @@ impl super::Validator { // The `coordinate` operand must be a vector of the appropriate size. if context - .resolve_type(coordinate, &self.valid_expression_set)? + .resolve_type_inner(coordinate, &self.valid_expression_set)? .image_storage_coordinates() .is_none_or(|coord_dim| coord_dim != dim) { @@ -1167,7 +1176,7 @@ impl super::Validator { // If present, `array_index` must be a scalar integer type. if let Some(expr) = array_index { if !matches!( - *context.resolve_type(expr, &self.valid_expression_set)?, + *context.resolve_type_inner(expr, &self.valid_expression_set)?, Ti::Scalar(crate::Scalar { kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, width: _, @@ -1188,7 +1197,7 @@ impl super::Validator { // The value we're writing had better match the scalar type // for `image`'s format. let actual_value_ty = - context.resolve_type(value, &self.valid_expression_set)?; + context.resolve_type_inner(value, &self.valid_expression_set)?; if actual_value_ty != &value_ty { return Err(FunctionError::InvalidStoreValue { actual: value, @@ -1273,7 +1282,7 @@ impl super::Validator { dim, } => { match context - .resolve_type(coordinate, &self.valid_expression_set)? + .resolve_type_inner(coordinate, &self.valid_expression_set)? .image_storage_coordinates() { Some(coord_dim) if coord_dim == dim => {} @@ -1293,7 +1302,9 @@ impl super::Validator { .with_span_handle(coordinate, context.expressions)); } if let Some(expr) = array_index { - match *context.resolve_type(expr, &self.valid_expression_set)? { + match *context + .resolve_type_inner(expr, &self.valid_expression_set)? + { Ti::Scalar(crate::Scalar { kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, width: _, @@ -1404,7 +1415,7 @@ impl super::Validator { } }; - if *context.resolve_type(value, &self.valid_expression_set)? != value_ty { + if *context.resolve_type_inner(value, &self.valid_expression_set)? != value_ty { return Err(FunctionError::InvalidImageAtomicValue(value) .with_span_handle(value, context.expressions)); } @@ -1412,7 +1423,7 @@ impl super::Validator { S::WorkGroupUniformLoad { pointer, result } => { stages &= super::ShaderStages::COMPUTE; let pointer_inner = - context.resolve_type(pointer, &self.valid_expression_set)?; + context.resolve_type_inner(pointer, &self.valid_expression_set)?; match *pointer_inner { Ti::Pointer { space: AddressSpace::WorkGroup, @@ -1468,9 +1479,10 @@ impl super::Validator { acceleration_structure, descriptor, } => { - match *context - .resolve_type(acceleration_structure, &self.valid_expression_set)? - { + match *context.resolve_type_inner( + acceleration_structure, + &self.valid_expression_set, + )? { Ti::AccelerationStructure { vertex_return } => { if (!vertex_return) && rq_vertex_return { return Err(FunctionError::MissingAccelerationStructureVertexReturn(acceleration_structure, query).with_span_static(span, "invalid acceleration structure")); @@ -1483,8 +1495,8 @@ impl super::Validator { .with_span_static(span, "invalid acceleration structure")) } } - let desc_ty_given = - context.resolve_type(descriptor, &self.valid_expression_set)?; + let desc_ty_given = context + .resolve_type_inner(descriptor, &self.valid_expression_set)?; let desc_ty_expected = context .special_types .ray_desc @@ -1498,7 +1510,7 @@ impl super::Validator { self.emit_expression(result, context)?; } crate::RayQueryFunction::GenerateIntersection { hit_t } => { - match *context.resolve_type(hit_t, &self.valid_expression_set)? { + match *context.resolve_type_inner(hit_t, &self.valid_expression_set)? { Ti::Scalar(crate::Scalar { kind: crate::ScalarKind::Float, width: _, @@ -1534,7 +1546,7 @@ impl super::Validator { } if let Some(predicate) = predicate { let predicate_inner = - context.resolve_type(predicate, &self.valid_expression_set)?; + context.resolve_type_inner(predicate, &self.valid_expression_set)?; if !matches!( *predicate_inner, crate::TypeInner::Scalar(crate::Scalar::BOOL,)