mirror of
https://github.com/AtsushiSakai/PythonRobotics.git
synced 2026-04-22 03:00:22 -04:00
Spiral Spanning Tree Coverage Path Planning (#355)
* First commit of Spiral Spanning Tree Coverage * Modify followed by first code review * fix pycodestyle error * modifies following 2nd code review
This commit is contained in:
BIN
PathPlanning/SpiralSpanningTreeCPP/map/test.png
Normal file
BIN
PathPlanning/SpiralSpanningTreeCPP/map/test.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 156 B |
BIN
PathPlanning/SpiralSpanningTreeCPP/map/test_2.png
Normal file
BIN
PathPlanning/SpiralSpanningTreeCPP/map/test_2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 150 B |
BIN
PathPlanning/SpiralSpanningTreeCPP/map/test_3.png
Normal file
BIN
PathPlanning/SpiralSpanningTreeCPP/map/test_3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 132 B |
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
Spiral Spanning Tree Coverage Path Planner
|
||||
|
||||
author: Todd Tang
|
||||
paper: Spiral-STC: An On-Line Coverage Algorithm of Grid Environments
|
||||
by a Mobile Robot - Gabriely et.al.
|
||||
link: https://ieeexplore.ieee.org/abstract/document/1013479
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
do_animation = True
|
||||
|
||||
|
||||
class SpiralSpanningTreeCoveragePlanner:
|
||||
def __init__(self, occ_map):
|
||||
self.origin_map_height = occ_map.shape[0]
|
||||
self.origin_map_width = occ_map.shape[1]
|
||||
|
||||
# original map resolution must be even
|
||||
if self.origin_map_height % 2 == 1 or self.origin_map_width % 2 == 1:
|
||||
sys.exit('original map width/height must be even \
|
||||
in grayscale .png format')
|
||||
|
||||
self.occ_map = occ_map
|
||||
self.merged_map_height = self.origin_map_height // 2
|
||||
self.merged_map_width = self.origin_map_width // 2
|
||||
|
||||
self.edge = []
|
||||
|
||||
def plan(self, start):
|
||||
"""plan
|
||||
|
||||
performing Spiral Spanning Tree Coverage path planning
|
||||
|
||||
:param start: the start node of Spiral Spanning Tree Coverage
|
||||
"""
|
||||
|
||||
visit_times = np.zeros(
|
||||
(self.merged_map_height, self.merged_map_width), dtype=np.int)
|
||||
visit_times[start[0]][start[1]] = 1
|
||||
|
||||
# generate route by
|
||||
# recusively call perform_spanning_tree_coverage() from start node
|
||||
route = []
|
||||
self.perform_spanning_tree_coverage(start, visit_times, route)
|
||||
|
||||
path = []
|
||||
# generate path from route
|
||||
for idx in range(len(route)-1):
|
||||
dp = abs(route[idx][0] - route[idx+1][0]) + \
|
||||
abs(route[idx][1] - route[idx+1][1])
|
||||
if dp == 0:
|
||||
# special handle for round-trip path
|
||||
path.append(self.get_round_trip_path(route[idx-1], route[idx]))
|
||||
elif dp == 1:
|
||||
path.append(self.move(route[idx], route[idx+1]))
|
||||
elif dp == 2:
|
||||
# special handle for non-adjacent route nodes
|
||||
mid_node = self.get_intermediate_node(route[idx], route[idx+1])
|
||||
path.append(self.move(route[idx], mid_node))
|
||||
path.append(self.move(mid_node, route[idx+1]))
|
||||
else:
|
||||
sys.exit('adjacent path node distance larger than 2')
|
||||
|
||||
return self.edge, route, path
|
||||
|
||||
def perform_spanning_tree_coverage(self, current_node, visit_times, route):
|
||||
"""perform_spanning_tree_coverage
|
||||
|
||||
recursive function for function <plan>
|
||||
|
||||
:param current_node: current node
|
||||
"""
|
||||
|
||||
def is_valid_node(i, j):
|
||||
is_i_valid_bounded = 0 <= i < self.merged_map_height
|
||||
is_j_valid_bounded = 0 <= j < self.merged_map_width
|
||||
if is_i_valid_bounded and is_j_valid_bounded:
|
||||
# free only when the 4 sub-cells are all free
|
||||
return bool(
|
||||
self.occ_map[2*i][2*j]
|
||||
and self.occ_map[2*i+1][2*j]
|
||||
and self.occ_map[2*i][2*j+1]
|
||||
and self.occ_map[2*i+1][2*j+1])
|
||||
|
||||
return False
|
||||
|
||||
# counter-clockwise neighbor finding order
|
||||
order = [[1, 0], [0, 1], [-1, 0], [0, -1]]
|
||||
|
||||
found = False
|
||||
route.append(current_node)
|
||||
for inc in order:
|
||||
ni, nj = current_node[0] + inc[0], current_node[1] + inc[1]
|
||||
if is_valid_node(ni, nj) and visit_times[ni][nj] == 0:
|
||||
neighbor_node = (ni, nj)
|
||||
self.edge.append((current_node, neighbor_node))
|
||||
found = True
|
||||
visit_times[ni][nj] += 1
|
||||
self.perform_spanning_tree_coverage(
|
||||
neighbor_node, visit_times, route)
|
||||
|
||||
# backtrace route from node with neighbors all visited
|
||||
# to first node with unvisited neighbor
|
||||
if not found:
|
||||
has_node_with_unvisited_ngb = False
|
||||
for node in reversed(route):
|
||||
# drop nodes that have been visited twice
|
||||
if visit_times[node[0]][node[1]] == 2:
|
||||
continue
|
||||
|
||||
visit_times[node[0]][node[1]] += 1
|
||||
route.append(node)
|
||||
|
||||
for inc in order:
|
||||
ni, nj = node[0] + inc[0], node[1] + inc[1]
|
||||
if is_valid_node(ni, nj) and visit_times[ni][nj] == 0:
|
||||
has_node_with_unvisited_ngb = True
|
||||
break
|
||||
|
||||
if has_node_with_unvisited_ngb:
|
||||
break
|
||||
|
||||
return route
|
||||
|
||||
def move(self, p, q):
|
||||
direction = self.get_vector_direction(p, q)
|
||||
# move east
|
||||
if direction == 'E':
|
||||
p = self.get_sub_node(p, 'SE')
|
||||
q = self.get_sub_node(q, 'SW')
|
||||
# move west
|
||||
elif direction == 'W':
|
||||
p = self.get_sub_node(p, 'NW')
|
||||
q = self.get_sub_node(q, 'NE')
|
||||
# move south
|
||||
elif direction == 'S':
|
||||
p = self.get_sub_node(p, 'SW')
|
||||
q = self.get_sub_node(q, 'NW')
|
||||
# move north
|
||||
elif direction == 'N':
|
||||
p = self.get_sub_node(p, 'NE')
|
||||
q = self.get_sub_node(q, 'SE')
|
||||
else:
|
||||
sys.exit('move direction error...')
|
||||
return [p, q]
|
||||
|
||||
def get_round_trip_path(self, last, pivot):
|
||||
direction = self.get_vector_direction(last, pivot)
|
||||
if direction == 'E':
|
||||
return [self.get_sub_node(pivot, 'SE'),
|
||||
self.get_sub_node(pivot, 'NE')]
|
||||
elif direction == 'S':
|
||||
return [self.get_sub_node(pivot, 'SW'),
|
||||
self.get_sub_node(pivot, 'SE')]
|
||||
elif direction == 'W':
|
||||
return [self.get_sub_node(pivot, 'NW'),
|
||||
self.get_sub_node(pivot, 'SW')]
|
||||
elif direction == 'N':
|
||||
return [self.get_sub_node(pivot, 'NE'),
|
||||
self.get_sub_node(pivot, 'NW')]
|
||||
else:
|
||||
sys.exit('get_round_trip_path: last->pivot direction error.')
|
||||
|
||||
def get_vector_direction(self, p, q):
|
||||
# east
|
||||
if p[0] == q[0] and p[1] < q[1]:
|
||||
return 'E'
|
||||
# west
|
||||
elif p[0] == q[0] and p[1] > q[1]:
|
||||
return 'W'
|
||||
# south
|
||||
elif p[0] < q[0] and p[1] == q[1]:
|
||||
return 'S'
|
||||
# north
|
||||
elif p[0] > q[0] and p[1] == q[1]:
|
||||
return 'N'
|
||||
else:
|
||||
sys.exit('get_vector_direction: Only E/W/S/N direction supported.')
|
||||
|
||||
def get_sub_node(self, node, direction):
|
||||
if direction == 'SE':
|
||||
return [2*node[0]+1, 2*node[1]+1]
|
||||
elif direction == 'SW':
|
||||
return [2*node[0]+1, 2*node[1]]
|
||||
elif direction == 'NE':
|
||||
return [2*node[0], 2*node[1]+1]
|
||||
elif direction == 'NW':
|
||||
return [2*node[0], 2*node[1]]
|
||||
else:
|
||||
sys.exit('get_sub_node: sub-node direction error.')
|
||||
|
||||
def get_interpolated_path(self, p, q):
|
||||
# direction p->q: southwest / northeast
|
||||
if (p[0] < q[0]) ^ (p[1] < q[1]):
|
||||
ipx = [p[0], p[0], q[0]]
|
||||
ipy = [p[1], q[1], q[1]]
|
||||
# direction p->q: southeast / northwest
|
||||
else:
|
||||
ipx = [p[0], q[0], q[0]]
|
||||
ipy = [p[1], p[1], q[1]]
|
||||
return ipx, ipy
|
||||
|
||||
def get_intermediate_node(self, p, q):
|
||||
p_ngb, q_ngb = set(), set()
|
||||
|
||||
for m, n in self.edge:
|
||||
if m == p:
|
||||
p_ngb.add(n)
|
||||
if n == p:
|
||||
p_ngb.add(m)
|
||||
if m == q:
|
||||
q_ngb.add(n)
|
||||
if n == q:
|
||||
q_ngb.add(m)
|
||||
|
||||
itsc = p_ngb.intersection(q_ngb)
|
||||
if len(itsc) == 0:
|
||||
sys.exit('get_intermediate_node: \
|
||||
no intermediate node between', p, q)
|
||||
elif len(itsc) == 1:
|
||||
return list(itsc)[0]
|
||||
else:
|
||||
sys.exit('get_intermediate_node: \
|
||||
more than 1 intermediate node between', p, q)
|
||||
|
||||
def visualize_path(self, edge, path, start):
|
||||
def coord_transform(p):
|
||||
return [2*p[1] + 0.5, 2*p[0] + 0.5]
|
||||
|
||||
if do_animation:
|
||||
last = path[0][0]
|
||||
trajectory = [[last[1]], [last[0]]]
|
||||
for p, q in path:
|
||||
distance = math.hypot(p[0]-last[0], p[1]-last[1])
|
||||
if distance <= 1.0:
|
||||
trajectory[0].append(p[1])
|
||||
trajectory[1].append(p[0])
|
||||
else:
|
||||
ipx, ipy = self.get_interpolated_path(last, p)
|
||||
trajectory[0].extend(ipy)
|
||||
trajectory[1].extend(ipx)
|
||||
|
||||
last = q
|
||||
|
||||
trajectory[0].append(last[1])
|
||||
trajectory[1].append(last[0])
|
||||
|
||||
for idx, state in enumerate(np.transpose(trajectory)):
|
||||
plt.cla()
|
||||
# for stopping simulation with the esc key.
|
||||
plt.gcf().canvas.mpl_connect(
|
||||
'key_release_event',
|
||||
lambda event: [exit(0) if event.key == 'escape' else None])
|
||||
|
||||
# draw spanning tree
|
||||
plt.imshow(self.occ_map, 'gray')
|
||||
for p, q in edge:
|
||||
p = coord_transform(p)
|
||||
q = coord_transform(q)
|
||||
plt.plot([p[0], q[0]], [p[1], q[1]], '-oc')
|
||||
sx, sy = coord_transform(start)
|
||||
plt.plot([sx], [sy], 'pr', markersize=10)
|
||||
|
||||
# draw move path
|
||||
plt.plot(trajectory[0][:idx+1], trajectory[1][:idx+1], '-k')
|
||||
plt.plot(state[0], state[1], 'or')
|
||||
plt.axis('equal')
|
||||
plt.grid(True)
|
||||
plt.pause(0.01)
|
||||
|
||||
else:
|
||||
# draw spanning tree
|
||||
plt.imshow(self.occ_map, 'gray')
|
||||
for p, q in edge:
|
||||
p = coord_transform(p)
|
||||
q = coord_transform(q)
|
||||
plt.plot([p[0], q[0]], [p[1], q[1]], '-oc')
|
||||
sx, sy = coord_transform(start)
|
||||
plt.plot([sx], [sy], 'pr', markersize=10)
|
||||
|
||||
# draw move path
|
||||
last = path[0][0]
|
||||
for p, q in path:
|
||||
distance = math.hypot(p[0]-last[0], p[1]-last[1])
|
||||
if distance == 1.0:
|
||||
plt.plot([last[1], p[1]], [last[0], p[0]], '-k')
|
||||
else:
|
||||
ipx, ipy = self.get_interpolated_path(last, p)
|
||||
plt.plot(ipy, ipx, '-k')
|
||||
plt.arrow(p[1], p[0], q[1]-p[1], q[0]-p[0], head_width=0.2)
|
||||
last = q
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
img = plt.imread(os.path.join(dir_path, 'map', 'test_2.png'))
|
||||
STC_planner = SpiralSpanningTreeCoveragePlanner(img)
|
||||
start = (10, 0)
|
||||
edge, route, path = STC_planner.plan(start)
|
||||
STC_planner.visualize_path(edge, path, start)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
58
tests/test_spiral_spanning_tree_coverage_path_planner.py
Normal file
58
tests/test_spiral_spanning_tree_coverage_path_planner.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import os
|
||||
import sys
|
||||
import matplotlib.pyplot as plt
|
||||
from unittest import TestCase
|
||||
|
||||
sys.path.append(os.path.dirname(
|
||||
os.path.abspath(__file__)) + "/../PathPlanning/SpiralSpanningTreeCPP")
|
||||
try:
|
||||
import spiral_spanning_tree_coverage_path_planner
|
||||
except ImportError:
|
||||
raise
|
||||
|
||||
spiral_spanning_tree_coverage_path_planner.do_animation = True
|
||||
|
||||
|
||||
class TestPlanning(TestCase):
|
||||
def spiral_stc_cpp(self, img, start):
|
||||
num_free = 0
|
||||
for i in range(img.shape[0]):
|
||||
for j in range(img.shape[1]):
|
||||
num_free += img[i][j]
|
||||
|
||||
STC_planner = spiral_spanning_tree_coverage_path_planner.\
|
||||
SpiralSpanningTreeCoveragePlanner(img)
|
||||
|
||||
edge, route, path = STC_planner.plan(start)
|
||||
|
||||
covered_nodes = set()
|
||||
for p, q in edge:
|
||||
covered_nodes.add(p)
|
||||
covered_nodes.add(q)
|
||||
|
||||
# assert complete coverage
|
||||
self.assertEqual(len(covered_nodes), num_free / 4)
|
||||
|
||||
def test_spiral_stc_cpp_1(self):
|
||||
img_dir = os.path.dirname(
|
||||
os.path.abspath(__file__)) + \
|
||||
"/../PathPlanning/SpiralSpanningTreeCPP"
|
||||
img = plt.imread(os.path.join(img_dir, 'map', 'test.png'))
|
||||
start = (0, 0)
|
||||
self.spiral_stc_cpp(img, start)
|
||||
|
||||
def test_spiral_stc_cpp_2(self):
|
||||
img_dir = os.path.dirname(
|
||||
os.path.abspath(__file__)) + \
|
||||
"/../PathPlanning/SpiralSpanningTreeCPP"
|
||||
img = plt.imread(os.path.join(img_dir, 'map', 'test_2.png'))
|
||||
start = (10, 0)
|
||||
self.spiral_stc_cpp(img, start)
|
||||
|
||||
def test_spiral_stc_cpp_3(self):
|
||||
img_dir = os.path.dirname(
|
||||
os.path.abspath(__file__)) + \
|
||||
"/../PathPlanning/SpiralSpanningTreeCPP"
|
||||
img = plt.imread(os.path.join(img_dir, 'map', 'test_3.png'))
|
||||
start = (0, 0)
|
||||
self.spiral_stc_cpp(img, start)
|
||||
Reference in New Issue
Block a user