fix: improved handling of start states

This commit is contained in:
shreyas-londhe
2025-05-10 13:07:06 +05:30
parent b54cd298ee
commit 22163b2af0
2 changed files with 120 additions and 38 deletions

View File

@@ -19,13 +19,13 @@ impl NFAGraph {
let re = PikeVM::new(pattern).map_err(|e| NFAError::RegexCompilation(e.to_string()))?;
let thompson_nfa = re.get_nfa();
let state_len = thompson_nfa.states().len() - 2;
let state_len = thompson_nfa.states().len() - thompson_nfa.start_anchored().as_usize();
let mut graph = Self::default();
graph.regex = pattern.to_string();
graph.initialize_nodes(state_len)?;
graph.process_all_states(&thompson_nfa)?;
graph.set_start_states(&thompson_nfa);
graph.start_states.insert(0);
graph.remove_epsilon_transitions()?;
graph.verify()?;
@@ -50,27 +50,48 @@ impl NFAGraph {
/// Processes all states from the Thompson NFA
fn process_all_states(&mut self, nfa: &NFA) -> NFAResult<()> {
for state_idx in 0..self.nodes.len() {
let state_id =
StateID::new(state_idx + 2).map_err(|e| NFAError::InvalidStateId(e.to_string()))?;
let state_id = StateID::new(state_idx + nfa.start_anchored().as_usize())
.map_err(|e| NFAError::InvalidStateId(e.to_string()))?;
match nfa.state(state_id) {
State::Match { .. } => {
self.accept_states.insert(state_idx);
}
State::ByteRange { trans } => {
self.add_byte_range_transition(state_idx, trans)?;
self.add_byte_range_transition(
nfa.start_anchored().as_usize(),
state_idx,
trans,
)?;
}
State::Sparse(sparse) => {
self.add_sparse_transitions(state_idx, &sparse.transitions)?;
self.add_sparse_transitions(
nfa.start_anchored().as_usize(),
state_idx,
&sparse.transitions,
)?;
}
State::Dense(dense) => {
self.add_dense_transitions(state_idx, &dense.transitions)?;
self.add_dense_transitions(
nfa.start_anchored().as_usize(),
state_idx,
&dense.transitions,
)?;
}
State::Union { alternates } => {
self.add_union_transitions(state_idx, alternates)?;
self.add_union_transitions(
nfa.start_anchored().as_usize(),
state_idx,
alternates,
)?;
}
State::BinaryUnion { alt1, alt2 } => {
self.add_binary_union_transitions(state_idx, alt1, alt2)?;
self.add_binary_union_transitions(
nfa.start_anchored().as_usize(),
state_idx,
alt1,
alt2,
)?;
}
State::Capture {
next,
@@ -78,11 +99,17 @@ impl NFAGraph {
slot,
..
} => {
self.add_capture_transition(state_idx, next, group_index, slot)?;
self.add_capture_transition(
nfa.start_anchored().as_usize(),
state_idx,
next,
group_index,
slot,
)?;
self.num_capture_groups = self.num_capture_groups.max(group_index.as_usize());
}
State::Look { next, .. } => {
self.add_look_transition(state_idx, next)?;
self.add_look_transition(nfa.start_anchored().as_usize(), state_idx, next)?;
}
State::Fail => {} // No transitions needed
}
@@ -91,13 +118,18 @@ impl NFAGraph {
}
/// Adds a byte range transition to the graph.
///
/// For every byte in the inclusive range `trans.start..=trans.end`, inserts the
/// re-based target of `trans.next` into `self.nodes[state_id].byte_transitions`,
/// creating the per-byte `BTreeSet` the first time a byte is seen.
///
/// * `anchored_state_id` - offset subtracted from `trans.next` so the target
///   becomes an index into `self.nodes`; the calls visible in this listing
///   pass `nfa.start_anchored().as_usize()`.
/// * `state_id` - index into `self.nodes` of the state the transition leaves.
/// * `trans` - byte range (`start`, `end`) and target (`next`) to record.
///
/// Always returns `Ok(())`.
// NOTE(review): two signatures and two `insert` lines appear below. This looks
// like a commit view with removed and added lines interleaved (the literal
// `- 2` form vs. the `- anchored_state_id` form) — confirm which pair is live.
// NOTE(review): the subtraction is unchecked; presumably every target id is
// >= the offset. Verify, since a smaller id would underflow.
fn add_byte_range_transition(&mut self, state_id: usize, trans: &Transition) -> NFAResult<()> {
fn add_byte_range_transition(
&mut self,
anchored_state_id: usize,
state_id: usize,
trans: &Transition,
) -> NFAResult<()> {
for byte in trans.start..=trans.end {
self.nodes[state_id]
.byte_transitions
.entry(byte)
.or_insert_with(BTreeSet::new)
.insert(trans.next.as_usize() - 2);
.insert(trans.next.as_usize() - anchored_state_id);
}
Ok(())
}
@@ -105,66 +137,84 @@ impl NFAGraph {
/// Adds transitions from a sparse transition set.
///
/// Forwards each entry of `transitions` to `add_byte_range_transition` for the
/// same source state, stopping at the first error via `?`.
///
/// * `anchored_state_id` - offset handed through unchanged to
///   `add_byte_range_transition`.
/// * `state_id` - index of the source state the transitions leave.
/// * `transitions` - the byte-range transitions to record.
///
/// Returns `Ok(())` once every transition has been added.
// NOTE(review): the loop body shows both a two-argument and a three-argument
// call; this looks like removed and added diff lines side by side — confirm
// which one is live.
fn add_sparse_transitions(
&mut self,
anchored_state_id: usize,
state_id: usize,
transitions: &[Transition],
) -> NFAResult<()> {
for trans in transitions {
self.add_byte_range_transition(state_id, trans)?;
self.add_byte_range_transition(anchored_state_id, state_id, trans)?;
}
Ok(())
}
/// Adds transitions from a dense transition table.
///
/// Walks `transitions` by position, treating the position as the input byte.
/// Entries equal to `StateID::ZERO` are skipped; every other entry is re-based
/// by subtracting an offset and inserted into
/// `self.nodes[state_id].byte_transitions` under that byte.
///
/// * `anchored_state_id` - offset subtracted from each target id.
/// * `state_id` - index into `self.nodes` of the source state.
/// * `transitions` - table of target ids indexed by byte value.
///
/// Always returns `Ok(())`.
// NOTE(review): two signatures and two `insert` lines appear below; this looks
// like removed and added diff lines interleaved (`- 2` vs.
// `- anchored_state_id`) — confirm which pair is live.
// NOTE(review): `byte as u8` truncates silently; presumably the table never
// has more than 256 entries — verify against the producer of `transitions`.
fn add_dense_transitions(&mut self, state_id: usize, transitions: &[StateID]) -> NFAResult<()> {
fn add_dense_transitions(
&mut self,
anchored_state_id: usize,
state_id: usize,
transitions: &[StateID],
) -> NFAResult<()> {
for (byte, &next) in transitions.iter().enumerate() {
if next != StateID::ZERO {
self.nodes[state_id]
.byte_transitions
.entry(byte as u8)
.or_insert_with(BTreeSet::new)
.insert(next.as_usize() - 2);
.insert(next.as_usize() - anchored_state_id);
}
}
Ok(())
}
/// Adds epsilon transitions for a union state.
///
/// Extends `self.nodes[state_id].epsilon_transitions` with every id in
/// `alternates`, each re-based by subtracting an offset.
///
/// * `anchored_state_id` - offset subtracted from each alternate's id.
/// * `state_id` - index into `self.nodes` of the union state.
/// * `alternates` - target ids of the union's branches.
///
/// Always returns `Ok(())`.
// NOTE(review): both a one-line signature with a `- 2` body and a multi-line
// signature with a `- anchored_state_id` body appear below; this looks like
// removed and added diff lines interleaved — confirm which version is live.
// NOTE(review): the subtraction is unchecked; presumably every alternate id is
// >= the offset — verify.
fn add_union_transitions(&mut self, state_id: usize, alternates: &[StateID]) -> NFAResult<()> {
self.nodes[state_id]
.epsilon_transitions
.extend(alternates.iter().map(|id| id.as_usize() - 2));
fn add_union_transitions(
&mut self,
anchored_state_id: usize,
state_id: usize,
alternates: &[StateID],
) -> NFAResult<()> {
self.nodes[state_id].epsilon_transitions.extend(
alternates
.iter()
.map(|id| id.as_usize() - anchored_state_id),
);
Ok(())
}
/// Adds epsilon transitions for a binary union state.
///
/// Inserts the re-based ids of both branches, `alt1` and `alt2`, into
/// `self.nodes[state_id].epsilon_transitions`.
///
/// * `anchored_state_id` - offset subtracted from each branch id.
/// * `state_id` - index into `self.nodes` of the union state.
/// * `alt1`, `alt2` - target ids of the two branches.
///
/// Always returns `Ok(())`.
// NOTE(review): each branch is inserted twice below, once with `- 2` and once
// with `- anchored_state_id`; this looks like removed and added diff lines
// interleaved — confirm which pair is live.
fn add_binary_union_transitions(
&mut self,
anchored_state_id: usize,
state_id: usize,
alt1: &StateID,
alt2: &StateID,
) -> NFAResult<()> {
let node = &mut self.nodes[state_id];
node.epsilon_transitions.insert(alt1.as_usize() - 2);
node.epsilon_transitions.insert(alt2.as_usize() - 2);
node.epsilon_transitions
.insert(alt1.as_usize() - anchored_state_id);
node.epsilon_transitions
.insert(alt2.as_usize() - anchored_state_id);
Ok(())
}
/// Adds an epsilon transition with capture group information
fn add_capture_transition(
&mut self,
anchored_state_id: usize,
state_id: usize,
next: &StateID,
group_index: &SmallIndex,
slot: &SmallIndex,
) -> NFAResult<()> {
let node = &mut self.nodes[state_id];
node.epsilon_transitions.insert(next.as_usize() - 2);
node.epsilon_transitions
.insert(next.as_usize() - anchored_state_id);
let group_idx = group_index.as_usize();
if group_idx > 0 {
let is_start = slot.as_usize() % 2 == 0;
node.capture_groups
.entry(next.as_usize() - 2)
.entry(next.as_usize() - anchored_state_id)
.or_insert_with(BTreeSet::new)
.insert((group_idx, is_start));
}
@@ -172,19 +222,18 @@ impl NFAGraph {
}
/// Adds an epsilon transition for a look-around state.
///
/// Inserts the re-based id of `next` into
/// `self.nodes[state_id].epsilon_transitions`. Only the target is recorded;
/// nothing about the look-around condition itself is stored here.
///
/// * `anchored_state_id` - offset subtracted from `next`.
/// * `state_id` - index into `self.nodes` of the look-around state.
/// * `next` - id of the state that follows the look-around.
///
/// Always returns `Ok(())`.
// NOTE(review): two signatures and two `insert` lines appear below; this looks
// like removed and added diff lines interleaved (`- 2` vs.
// `- anchored_state_id`) — confirm which pair is live.
fn add_look_transition(&mut self, state_id: usize, next: &StateID) -> NFAResult<()> {
fn add_look_transition(
&mut self,
anchored_state_id: usize,
state_id: usize,
next: &StateID,
) -> NFAResult<()> {
self.nodes[state_id]
.epsilon_transitions
.insert(next.as_usize() - 2);
.insert(next.as_usize() - anchored_state_id);
Ok(())
}
/// Sets the start states for the NFA.
///
/// Inserts a single entry into `self.start_states`: the id returned by
/// `nfa.start_anchored()`, lowered by the literal offset 2.
// NOTE(review): the constructor lines in this listing show both a call to this
// method and a direct `graph.start_states.insert(0)`; this helper may be the
// removed side of the commit — confirm it still exists before relying on it.
fn set_start_states(&mut self, nfa: &NFA) {
self.start_states
.insert(nfa.start_anchored().as_usize() - 2);
}
pub fn pretty_print(&self) {
println!("\n=== NFA Graph ===");
println!("Regex: {}", self.regex);

View File

@@ -98,15 +98,48 @@ impl NFAGraph {
}
}
// Handle start states - only make byte transition states reachable via epsilon into start states
for &start in &self.start_states {
new_start_states.insert(start);
// Handle start states
// Preserve original start states to iterate over them
let original_start_states_snapshot: BTreeSet<usize> =
self.start_states.iter().copied().collect();
new_start_states.clear();
for &r_state in &closures[start].states {
if has_byte_transitions[r_state] {
new_start_states.insert(r_state);
for &orig_start in &original_start_states_snapshot {
new_start_states.insert(orig_start); // The original start state is always kept
// Check if the closure of this original start state contains any START captures.
// If so, we don't want to create alternative start points from within this closure,
// as that might allow bypassing these essential start captures.
let mut has_start_captures_in_orig_closure = false;
if let Some(orig_closure) = closures.get(orig_start) {
for &(_, (_group_id, is_start_event)) in &orig_closure.captures {
if is_start_event {
has_start_captures_in_orig_closure = true;
break;
}
}
}
if !has_start_captures_in_orig_closure {
// If no start captures in orig_start's closure, it's safe to add
// other states from its closure that have byte transitions as new start states.
if let Some(orig_closure) = closures.get(orig_start) {
for &r_state in &orig_closure.states {
if r_state == orig_start {
continue;
}
// Check if r_state (a state reachable via epsilon from orig_start)
// itself is the source of a byte transition.
// The has_byte_transitions vec was populated based on nodes[r_state].byte_transitions
if r_state < has_byte_transitions.len() && has_byte_transitions[r_state] {
new_start_states.insert(r_state);
}
}
}
}
// If has_start_captures_in_orig_closure is true, we *only* keep orig_start.
// This forces paths through orig_start, ensuring its transitions (which will
// have correctly accumulated these start captures) are used.
}
// Apply changes