mirror of
https://github.com/AtsushiSakai/PythonRobotics.git
synced 2026-04-22 03:00:22 -04:00
Add expanded node set to SpaceTime AStar (#1183)
* speed up spacetime astar * forgot to include hash impl on Position * add condition to test on node expansions * remove heuristic from Node __hash__ impl * update rst with note about optimization
This commit is contained in:
committed by
GitHub
parent
73ebcd85dc
commit
1308e76424
@@ -30,6 +30,9 @@ class Position:
|
||||
f"Subtraction not supported for Position and {type(other)}"
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.x, self.y))
|
||||
|
||||
|
||||
class ObstacleArrangement(Enum):
|
||||
# Random obstacle positions and movements
|
||||
|
||||
@@ -20,7 +20,7 @@ from collections.abc import Generator
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from functools import total_ordering
|
||||
|
||||
import time
|
||||
|
||||
# Seed randomness for reproducibility
|
||||
RANDOM_SEED = 50
|
||||
@@ -48,11 +48,17 @@ class Node:
|
||||
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
|
||||
return (self.time + self.heuristic) < (other.time + other.heuristic)
|
||||
|
||||
"""
|
||||
Note: cost and heuristic are not included in eq or hash, since they will always be the same
|
||||
for a given (position, time) pair. Including either cost or heuristic would be redundant.
|
||||
"""
|
||||
def __eq__(self, other: object):
|
||||
if not isinstance(other, Node):
|
||||
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
|
||||
return self.position == other.position and self.time == other.time
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.position, self.time))
|
||||
|
||||
class NodePath:
|
||||
path: list[Node]
|
||||
@@ -86,6 +92,8 @@ class SpaceTimeAStar:
|
||||
grid: Grid
|
||||
start: Position
|
||||
goal: Position
|
||||
# Used to evaluate solutions
|
||||
expanded_node_count: int = -1
|
||||
|
||||
def __init__(self, grid: Grid, start: Position, goal: Position):
|
||||
self.grid = grid
|
||||
@@ -98,7 +106,8 @@ class SpaceTimeAStar:
|
||||
open_set, Node(self.start, 0, self.calculate_heuristic(self.start), -1)
|
||||
)
|
||||
|
||||
expanded_set: list[Node] = []
|
||||
expanded_list: list[Node] = []
|
||||
expanded_set: set[Node] = set()
|
||||
while open_set:
|
||||
expanded_node: Node = heapq.heappop(open_set)
|
||||
if verbose:
|
||||
@@ -110,23 +119,25 @@ class SpaceTimeAStar:
|
||||
continue
|
||||
|
||||
if expanded_node.position == self.goal:
|
||||
print(f"Found path to goal after {len(expanded_set)} expansions")
|
||||
print(f"Found path to goal after {len(expanded_list)} expansions")
|
||||
path = []
|
||||
path_walker: Node = expanded_node
|
||||
while True:
|
||||
path.append(path_walker)
|
||||
if path_walker.parent_index == -1:
|
||||
break
|
||||
path_walker = expanded_set[path_walker.parent_index]
|
||||
path_walker = expanded_list[path_walker.parent_index]
|
||||
|
||||
# reverse path so it goes start -> goal
|
||||
path.reverse()
|
||||
self.expanded_node_count = len(expanded_set)
|
||||
return NodePath(path)
|
||||
|
||||
expanded_idx = len(expanded_set)
|
||||
expanded_set.append(expanded_node)
|
||||
expanded_idx = len(expanded_list)
|
||||
expanded_list.append(expanded_node)
|
||||
expanded_set.add(expanded_node)
|
||||
|
||||
for child in self.generate_successors(expanded_node, expanded_idx, verbose):
|
||||
for child in self.generate_successors(expanded_node, expanded_idx, verbose, expanded_set):
|
||||
heapq.heappush(open_set, child)
|
||||
|
||||
raise Exception("No path found")
|
||||
@@ -135,7 +146,7 @@ class SpaceTimeAStar:
|
||||
Generate possible successors of the provided `parent_node`
|
||||
"""
|
||||
def generate_successors(
|
||||
self, parent_node: Node, parent_node_idx: int, verbose: bool
|
||||
self, parent_node: Node, parent_node_idx: int, verbose: bool, expanded_set: set[Node]
|
||||
) -> Generator[Node, None, None]:
|
||||
diffs = [
|
||||
Position(0, 0),
|
||||
@@ -146,13 +157,17 @@ class SpaceTimeAStar:
|
||||
]
|
||||
for diff in diffs:
|
||||
new_pos = parent_node.position + diff
|
||||
new_node = Node(
|
||||
new_pos,
|
||||
parent_node.time + 1,
|
||||
self.calculate_heuristic(new_pos),
|
||||
parent_node_idx,
|
||||
)
|
||||
|
||||
if new_node in expanded_set:
|
||||
continue
|
||||
|
||||
if self.grid.valid_position(new_pos, parent_node.time + 1):
|
||||
new_node = Node(
|
||||
new_pos,
|
||||
parent_node.time + 1,
|
||||
self.calculate_heuristic(new_pos),
|
||||
parent_node_idx,
|
||||
)
|
||||
if verbose:
|
||||
print("\tNew successor node: ", new_node)
|
||||
yield new_node
|
||||
@@ -166,9 +181,12 @@ show_animation = True
|
||||
verbose = False
|
||||
|
||||
def main():
|
||||
start = Position(1, 11)
|
||||
start = Position(1, 5)
|
||||
goal = Position(19, 19)
|
||||
grid_side_length = 21
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
grid = Grid(
|
||||
np.array([grid_side_length, grid_side_length]),
|
||||
num_obstacles=40,
|
||||
@@ -179,6 +197,9 @@ def main():
|
||||
planner = SpaceTimeAStar(grid, start, goal)
|
||||
path = planner.plan(verbose)
|
||||
|
||||
runtime = time.time() - start_time
|
||||
print(f"Planning took: {runtime:.5f} seconds")
|
||||
|
||||
if verbose:
|
||||
print(f"Path: {path}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user