mirror of
https://github.com/zkemail/zk-regex.git
synced 2026-01-10 06:07:58 -05:00
fix: improved handling of start states
This commit is contained in:
@@ -19,13 +19,13 @@ impl NFAGraph {
|
||||
let re = PikeVM::new(pattern).map_err(|e| NFAError::RegexCompilation(e.to_string()))?;
|
||||
let thompson_nfa = re.get_nfa();
|
||||
|
||||
let state_len = thompson_nfa.states().len() - 2;
|
||||
let state_len = thompson_nfa.states().len() - thompson_nfa.start_anchored().as_usize();
|
||||
|
||||
let mut graph = Self::default();
|
||||
graph.regex = pattern.to_string();
|
||||
graph.initialize_nodes(state_len)?;
|
||||
graph.process_all_states(&thompson_nfa)?;
|
||||
graph.set_start_states(&thompson_nfa);
|
||||
graph.start_states.insert(0);
|
||||
graph.remove_epsilon_transitions()?;
|
||||
|
||||
graph.verify()?;
|
||||
@@ -50,27 +50,48 @@ impl NFAGraph {
|
||||
/// Processes all states from the Thompson NFA
|
||||
fn process_all_states(&mut self, nfa: &NFA) -> NFAResult<()> {
|
||||
for state_idx in 0..self.nodes.len() {
|
||||
let state_id =
|
||||
StateID::new(state_idx + 2).map_err(|e| NFAError::InvalidStateId(e.to_string()))?;
|
||||
let state_id = StateID::new(state_idx + nfa.start_anchored().as_usize())
|
||||
.map_err(|e| NFAError::InvalidStateId(e.to_string()))?;
|
||||
|
||||
match nfa.state(state_id) {
|
||||
State::Match { .. } => {
|
||||
self.accept_states.insert(state_idx);
|
||||
}
|
||||
State::ByteRange { trans } => {
|
||||
self.add_byte_range_transition(state_idx, trans)?;
|
||||
self.add_byte_range_transition(
|
||||
nfa.start_anchored().as_usize(),
|
||||
state_idx,
|
||||
trans,
|
||||
)?;
|
||||
}
|
||||
State::Sparse(sparse) => {
|
||||
self.add_sparse_transitions(state_idx, &sparse.transitions)?;
|
||||
self.add_sparse_transitions(
|
||||
nfa.start_anchored().as_usize(),
|
||||
state_idx,
|
||||
&sparse.transitions,
|
||||
)?;
|
||||
}
|
||||
State::Dense(dense) => {
|
||||
self.add_dense_transitions(state_idx, &dense.transitions)?;
|
||||
self.add_dense_transitions(
|
||||
nfa.start_anchored().as_usize(),
|
||||
state_idx,
|
||||
&dense.transitions,
|
||||
)?;
|
||||
}
|
||||
State::Union { alternates } => {
|
||||
self.add_union_transitions(state_idx, alternates)?;
|
||||
self.add_union_transitions(
|
||||
nfa.start_anchored().as_usize(),
|
||||
state_idx,
|
||||
alternates,
|
||||
)?;
|
||||
}
|
||||
State::BinaryUnion { alt1, alt2 } => {
|
||||
self.add_binary_union_transitions(state_idx, alt1, alt2)?;
|
||||
self.add_binary_union_transitions(
|
||||
nfa.start_anchored().as_usize(),
|
||||
state_idx,
|
||||
alt1,
|
||||
alt2,
|
||||
)?;
|
||||
}
|
||||
State::Capture {
|
||||
next,
|
||||
@@ -78,11 +99,17 @@ impl NFAGraph {
|
||||
slot,
|
||||
..
|
||||
} => {
|
||||
self.add_capture_transition(state_idx, next, group_index, slot)?;
|
||||
self.add_capture_transition(
|
||||
nfa.start_anchored().as_usize(),
|
||||
state_idx,
|
||||
next,
|
||||
group_index,
|
||||
slot,
|
||||
)?;
|
||||
self.num_capture_groups = self.num_capture_groups.max(group_index.as_usize());
|
||||
}
|
||||
State::Look { next, .. } => {
|
||||
self.add_look_transition(state_idx, next)?;
|
||||
self.add_look_transition(nfa.start_anchored().as_usize(), state_idx, next)?;
|
||||
}
|
||||
State::Fail => {} // No transitions needed
|
||||
}
|
||||
@@ -91,13 +118,18 @@ impl NFAGraph {
|
||||
}
|
||||
|
||||
/// Adds a byte range transition to the graph
|
||||
fn add_byte_range_transition(&mut self, state_id: usize, trans: &Transition) -> NFAResult<()> {
|
||||
fn add_byte_range_transition(
|
||||
&mut self,
|
||||
anchored_state_id: usize,
|
||||
state_id: usize,
|
||||
trans: &Transition,
|
||||
) -> NFAResult<()> {
|
||||
for byte in trans.start..=trans.end {
|
||||
self.nodes[state_id]
|
||||
.byte_transitions
|
||||
.entry(byte)
|
||||
.or_insert_with(BTreeSet::new)
|
||||
.insert(trans.next.as_usize() - 2);
|
||||
.insert(trans.next.as_usize() - anchored_state_id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -105,66 +137,84 @@ impl NFAGraph {
|
||||
/// Adds transitions from a sparse transition set
|
||||
fn add_sparse_transitions(
|
||||
&mut self,
|
||||
anchored_state_id: usize,
|
||||
state_id: usize,
|
||||
transitions: &[Transition],
|
||||
) -> NFAResult<()> {
|
||||
for trans in transitions {
|
||||
self.add_byte_range_transition(state_id, trans)?;
|
||||
self.add_byte_range_transition(anchored_state_id, state_id, trans)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds transitions from a dense transition table
|
||||
fn add_dense_transitions(&mut self, state_id: usize, transitions: &[StateID]) -> NFAResult<()> {
|
||||
fn add_dense_transitions(
|
||||
&mut self,
|
||||
anchored_state_id: usize,
|
||||
state_id: usize,
|
||||
transitions: &[StateID],
|
||||
) -> NFAResult<()> {
|
||||
for (byte, &next) in transitions.iter().enumerate() {
|
||||
if next != StateID::ZERO {
|
||||
self.nodes[state_id]
|
||||
.byte_transitions
|
||||
.entry(byte as u8)
|
||||
.or_insert_with(BTreeSet::new)
|
||||
.insert(next.as_usize() - 2);
|
||||
.insert(next.as_usize() - anchored_state_id);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds epsilon transitions for a union state
|
||||
fn add_union_transitions(&mut self, state_id: usize, alternates: &[StateID]) -> NFAResult<()> {
|
||||
self.nodes[state_id]
|
||||
.epsilon_transitions
|
||||
.extend(alternates.iter().map(|id| id.as_usize() - 2));
|
||||
fn add_union_transitions(
|
||||
&mut self,
|
||||
anchored_state_id: usize,
|
||||
state_id: usize,
|
||||
alternates: &[StateID],
|
||||
) -> NFAResult<()> {
|
||||
self.nodes[state_id].epsilon_transitions.extend(
|
||||
alternates
|
||||
.iter()
|
||||
.map(|id| id.as_usize() - anchored_state_id),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds epsilon transitions for a binary union state
|
||||
fn add_binary_union_transitions(
|
||||
&mut self,
|
||||
anchored_state_id: usize,
|
||||
state_id: usize,
|
||||
alt1: &StateID,
|
||||
alt2: &StateID,
|
||||
) -> NFAResult<()> {
|
||||
let node = &mut self.nodes[state_id];
|
||||
node.epsilon_transitions.insert(alt1.as_usize() - 2);
|
||||
node.epsilon_transitions.insert(alt2.as_usize() - 2);
|
||||
node.epsilon_transitions
|
||||
.insert(alt1.as_usize() - anchored_state_id);
|
||||
node.epsilon_transitions
|
||||
.insert(alt2.as_usize() - anchored_state_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds an epsilon transition with capture group information
|
||||
fn add_capture_transition(
|
||||
&mut self,
|
||||
anchored_state_id: usize,
|
||||
state_id: usize,
|
||||
next: &StateID,
|
||||
group_index: &SmallIndex,
|
||||
slot: &SmallIndex,
|
||||
) -> NFAResult<()> {
|
||||
let node = &mut self.nodes[state_id];
|
||||
node.epsilon_transitions.insert(next.as_usize() - 2);
|
||||
node.epsilon_transitions
|
||||
.insert(next.as_usize() - anchored_state_id);
|
||||
|
||||
let group_idx = group_index.as_usize();
|
||||
if group_idx > 0 {
|
||||
let is_start = slot.as_usize() % 2 == 0;
|
||||
node.capture_groups
|
||||
.entry(next.as_usize() - 2)
|
||||
.entry(next.as_usize() - anchored_state_id)
|
||||
.or_insert_with(BTreeSet::new)
|
||||
.insert((group_idx, is_start));
|
||||
}
|
||||
@@ -172,19 +222,18 @@ impl NFAGraph {
|
||||
}
|
||||
|
||||
/// Adds an epsilon transition for a look-around state
|
||||
fn add_look_transition(&mut self, state_id: usize, next: &StateID) -> NFAResult<()> {
|
||||
fn add_look_transition(
|
||||
&mut self,
|
||||
anchored_state_id: usize,
|
||||
state_id: usize,
|
||||
next: &StateID,
|
||||
) -> NFAResult<()> {
|
||||
self.nodes[state_id]
|
||||
.epsilon_transitions
|
||||
.insert(next.as_usize() - 2);
|
||||
.insert(next.as_usize() - anchored_state_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the start states for the NFA
|
||||
fn set_start_states(&mut self, nfa: &NFA) {
|
||||
self.start_states
|
||||
.insert(nfa.start_anchored().as_usize() - 2);
|
||||
}
|
||||
|
||||
pub fn pretty_print(&self) {
|
||||
println!("\n=== NFA Graph ===");
|
||||
println!("Regex: {}", self.regex);
|
||||
|
||||
@@ -98,15 +98,48 @@ impl NFAGraph {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle start states - only make byte transition states reachable via epsilon into start states
|
||||
for &start in &self.start_states {
|
||||
new_start_states.insert(start);
|
||||
// Handle start states
|
||||
// Preserve original start states to iterate over them
|
||||
let original_start_states_snapshot: BTreeSet<usize> =
|
||||
self.start_states.iter().copied().collect();
|
||||
new_start_states.clear();
|
||||
|
||||
for &r_state in &closures[start].states {
|
||||
if has_byte_transitions[r_state] {
|
||||
new_start_states.insert(r_state);
|
||||
for &orig_start in &original_start_states_snapshot {
|
||||
new_start_states.insert(orig_start); // The original start state is always kept
|
||||
|
||||
// Check if the closure of this original start state contains any START captures.
|
||||
// If so, we don't want to create alternative start points from within this closure,
|
||||
// as that might allow bypassing these essential start captures.
|
||||
let mut has_start_captures_in_orig_closure = false;
|
||||
if let Some(orig_closure) = closures.get(orig_start) {
|
||||
for &(_, (_group_id, is_start_event)) in &orig_closure.captures {
|
||||
if is_start_event {
|
||||
has_start_captures_in_orig_closure = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has_start_captures_in_orig_closure {
|
||||
// If no start captures in orig_start's closure, it's safe to add
|
||||
// other states from its closure that have byte transitions as new start states.
|
||||
if let Some(orig_closure) = closures.get(orig_start) {
|
||||
for &r_state in &orig_closure.states {
|
||||
if r_state == orig_start {
|
||||
continue;
|
||||
}
|
||||
// Check if r_state (a state reachable via epsilon from orig_start)
|
||||
// itself is the source of a byte transition.
|
||||
// The has_byte_transitions vec was populated based on nodes[r_state].byte_transitions
|
||||
if r_state < has_byte_transitions.len() && has_byte_transitions[r_state] {
|
||||
new_start_states.insert(r_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// If has_start_captures_in_orig_closure is true, we *only* keep orig_start.
|
||||
// This forces paths through orig_start, ensuring its transitions (which will
|
||||
// have correctly accumulated these start captures) are used.
|
||||
}
|
||||
|
||||
// Apply changes
|
||||
|
||||
Reference in New Issue
Block a user