Add expanded node set to SpaceTime AStar (#1183)

* speed up spacetime astar

* forgot to include hash impl on Position

* add condition to test on node expansions

* remove heuristic from Node __hash__ impl

* update rst with note about optimization
This commit is contained in:
Jonathan Schwartz
2025-03-13 10:24:53 -04:00
committed by GitHub
parent 73ebcd85dc
commit 1308e76424
4 changed files with 57 additions and 15 deletions

View File

@@ -30,6 +30,9 @@ class Position:
f"Subtraction not supported for Position and {type(other)}"
)
def __hash__(self):
return hash((self.x, self.y))
class ObstacleArrangement(Enum):
# Random obstacle positions and movements

View File

@@ -20,7 +20,7 @@ from collections.abc import Generator
import random
from dataclasses import dataclass
from functools import total_ordering
import time
# Seed randomness for reproducibility
RANDOM_SEED = 50
@@ -48,11 +48,17 @@ class Node:
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
return (self.time + self.heuristic) < (other.time + other.heuristic)
"""
Note: cost and heuristic are not included in eq or hash, since they will always be the same
for a given (position, time) pair. Including either cost or heuristic would be redundant.
"""
def __eq__(self, other: object):
if not isinstance(other, Node):
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
return self.position == other.position and self.time == other.time
def __hash__(self):
return hash((self.position, self.time))
class NodePath:
path: list[Node]
@@ -86,6 +92,8 @@ class SpaceTimeAStar:
grid: Grid
start: Position
goal: Position
# Used to evaluate solutions
expanded_node_count: int = -1
def __init__(self, grid: Grid, start: Position, goal: Position):
self.grid = grid
@@ -98,7 +106,8 @@ class SpaceTimeAStar:
open_set, Node(self.start, 0, self.calculate_heuristic(self.start), -1)
)
expanded_set: list[Node] = []
expanded_list: list[Node] = []
expanded_set: set[Node] = set()
while open_set:
expanded_node: Node = heapq.heappop(open_set)
if verbose:
@@ -110,23 +119,25 @@ class SpaceTimeAStar:
continue
if expanded_node.position == self.goal:
print(f"Found path to goal after {len(expanded_set)} expansions")
print(f"Found path to goal after {len(expanded_list)} expansions")
path = []
path_walker: Node = expanded_node
while True:
path.append(path_walker)
if path_walker.parent_index == -1:
break
path_walker = expanded_set[path_walker.parent_index]
path_walker = expanded_list[path_walker.parent_index]
# reverse path so it goes start -> goal
path.reverse()
self.expanded_node_count = len(expanded_set)
return NodePath(path)
expanded_idx = len(expanded_set)
expanded_set.append(expanded_node)
expanded_idx = len(expanded_list)
expanded_list.append(expanded_node)
expanded_set.add(expanded_node)
for child in self.generate_successors(expanded_node, expanded_idx, verbose):
for child in self.generate_successors(expanded_node, expanded_idx, verbose, expanded_set):
heapq.heappush(open_set, child)
raise Exception("No path found")
@@ -135,7 +146,7 @@ class SpaceTimeAStar:
Generate possible successors of the provided `parent_node`
"""
def generate_successors(
self, parent_node: Node, parent_node_idx: int, verbose: bool
self, parent_node: Node, parent_node_idx: int, verbose: bool, expanded_set: set[Node]
) -> Generator[Node, None, None]:
diffs = [
Position(0, 0),
@@ -146,13 +157,17 @@ class SpaceTimeAStar:
]
for diff in diffs:
new_pos = parent_node.position + diff
new_node = Node(
new_pos,
parent_node.time + 1,
self.calculate_heuristic(new_pos),
parent_node_idx,
)
if new_node in expanded_set:
continue
if self.grid.valid_position(new_pos, parent_node.time + 1):
new_node = Node(
new_pos,
parent_node.time + 1,
self.calculate_heuristic(new_pos),
parent_node_idx,
)
if verbose:
print("\tNew successor node: ", new_node)
yield new_node
@@ -166,9 +181,12 @@ show_animation = True
verbose = False
def main():
start = Position(1, 11)
start = Position(1, 5)
goal = Position(19, 19)
grid_side_length = 21
start_time = time.time()
grid = Grid(
np.array([grid_side_length, grid_side_length]),
num_obstacles=40,
@@ -179,6 +197,9 @@ def main():
planner = SpaceTimeAStar(grid, start, goal)
path = planner.plan(verbose)
runtime = time.time() - start_time
print(f"Planning took: {runtime:.5f} seconds")
if verbose:
print(f"Path: {path}")