diff --git a/naga/src/arena/handle_set.rs b/naga/src/arena/handle_set.rs index 47c2937a23..f1670dcf4f 100644 --- a/naga/src/arena/handle_set.rs +++ b/naga/src/arena/handle_set.rs @@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle, UniqueArena}; /// A set of `Handle` values. +#[derive(Debug)] pub struct HandleSet { /// Bound on indexes of handles stored in this set. len: usize, @@ -15,6 +16,16 @@ pub struct HandleSet { } impl HandleSet { + /// Return a new, empty `HandleSet`. + pub fn new() -> Self { + Self { + len: 0, + members: bit_set::BitSet::new(), + as_keys: std::marker::PhantomData, + } + } + + /// Return a new, empty `HandleSet`, sized to hold handles from `arena`. pub fn for_arena(arena: &impl ArenaType) -> Self { let len = arena.len(); Self { @@ -24,6 +35,17 @@ impl HandleSet { } } + /// Remove all members from `self`. + pub fn clear(&mut self) { + self.members.clear(); + } + + /// Remove all members from `self`, and reserve space to hold handles from `arena`. + pub fn clear_for_arena(&mut self, arena: &impl ArenaType) { + self.members.clear(); + self.members.reserve_len(arena.len()); + } + /// Return an iterator over all handles that could be made members /// of this set. pub fn all_possible(&self) -> impl Iterator> { @@ -37,6 +59,13 @@ impl HandleSet { self.members.insert(handle.index()) } + /// Remove `handle` from the set. + /// + /// Returns `true` if `handle` was present in the set. + pub fn remove(&mut self, handle: Handle) -> bool { + self.members.remove(handle.index()) + } + /// Add handles from `iter` to the set. pub fn insert_iter(&mut self, iter: impl IntoIterator>) { for handle in iter { diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index d8c4791285..b2f9c8c47f 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1,5 +1,5 @@ -use crate::arena::Handle; use crate::arena::{Arena, UniqueArena}; +use crate::arena::{Handle, HandleSet}; use super::validate_atomic_compare_exchange_struct; @@ -10,8 +10,6 @@ use super::{ use crate::span::WithSpan; use crate::span::{AddSpan as _, MapErrWithSpan as _}; -use bit_set::BitSet; - #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum CallError { @@ -257,9 +255,9 @@ impl<'a> BlockContext<'a> { fn resolve_type_impl( &self, handle: Handle, - valid_expressions: &BitSet, + valid_expressions: &HandleSet, ) -> Result<&crate::TypeInner, WithSpan> { - if !valid_expressions.contains(handle.index()) { + if !valid_expressions.contains(handle) { Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions)) } else { Ok(self.info[handle].ty.inner_with(self.types)) @@ -269,7 +267,7 @@ impl<'a> BlockContext<'a> { fn resolve_type( &self, handle: Handle, - valid_expressions: &BitSet, + valid_expressions: &HandleSet, ) -> Result<&crate::TypeInner, WithSpan> { self.resolve_type_impl(handle, valid_expressions) .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span()) @@ -315,7 +313,7 @@ impl super::Validator { } if let Some(expr) = result { - if self.valid_expression_set.insert(expr.index()) { + if self.valid_expression_set.insert(expr) { self.valid_expression_list.push(expr); } else { return Err(CallError::ResultAlreadyInScope(expr) @@ -348,7 +346,7 @@ impl super::Validator { handle: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - if self.valid_expression_set.insert(handle.index()) { + if self.valid_expression_set.insert(handle) { self.valid_expression_list.push(handle); Ok(()) } else { @@ -864,7 +862,7 @@ impl super::Validator { } for handle in self.valid_expression_list.drain(base_expression_count..) { - self.valid_expression_set.remove(handle.index()); + self.valid_expression_set.remove(handle); } } S::Break => { @@ -1321,7 +1319,7 @@ impl super::Validator { let base_expression_count = self.valid_expression_list.len(); let info = self.validate_block_impl(statements, context)?; for handle in self.valid_expression_list.drain(base_expression_count..) { - self.valid_expression_set.remove(handle.index()); + self.valid_expression_set.remove(handle); } Ok(info) } @@ -1429,12 +1427,12 @@ impl super::Validator { } } - self.valid_expression_set.clear(); + self.valid_expression_set.clear_for_arena(&fun.expressions); self.valid_expression_list.clear(); self.needs_visit.clear(); for (handle, expr) in fun.expressions.iter() { if expr.needs_pre_emit() { - self.valid_expression_set.insert(handle.index()); + self.valid_expression_set.insert(handle); } if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { // Mark expressions that need to be visited by a particular kind of diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index ce1c1eab35..932a6fdb1e 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -11,7 +11,7 @@ mod interface; mod r#type; use crate::{ - arena::Handle, + arena::{Handle, HandleSet}, proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution}, FastHashSet, }; @@ -259,7 +259,7 @@ pub struct Validator { #[allow(dead_code)] switch_values: FastHashSet, valid_expression_list: Vec>, - valid_expression_set: BitSet, + valid_expression_set: HandleSet, override_ids: FastHashSet, allow_overrides: bool, @@ -448,7 +448,7 @@ impl Validator { ep_resource_bindings: FastHashSet::default(), switch_values: FastHashSet::default(), valid_expression_list: Vec::new(), - valid_expression_set: BitSet::new(), + valid_expression_set: HandleSet::new(), override_ids: FastHashSet::default(), allow_overrides: true, needs_visit: BitSet::new(),