start implementing.

This commit is contained in:
Atsushi Sakai
2019-07-19 20:35:55 +09:00
parent 9b8f2bd88a
commit 8ff5df7463

View File

@@ -47,7 +47,7 @@ class RRT:
self.expand_dis = expand_dis
self.goal_sample_rate = goal_sample_rate
self.max_iter = max_iter
self.obstacleList = obstacle_list
self.obstacle_list = obstacle_list
self.node_list = []
def planning(self, animation=True):
@@ -59,14 +59,14 @@ class RRT:
self.node_list = [self.start]
for i in range(self.max_iter):
rnd = self.get_random_point()
nearest_ind = self.get_nearest_list_index(self.node_list, rnd)
rnd_node = self.get_random_node()
nearest_ind = self.get_nearest_list_index(self.node_list, rnd_node)
nearest_node = self.node_list[nearest_ind]
new_node = self.steer(rnd, nearest_node)
new_node = self.steer(nearest_node, rnd_node)
new_node.parent = nearest_node
if not self.check_collision(new_node, self.obstacleList):
if not self.check_collision(new_node, self.obstacle_list):
continue
self.node_list.append(new_node)
@@ -78,16 +78,21 @@ class RRT:
return self.generate_final_course(len(self.node_list) - 1)
if animation and i % 5:
self.draw_graph(rnd)
self.draw_graph(rnd_node)
return None # cannot find path
def steer(self, rnd, nearest_node):
new_node = self.Node(rnd[0], rnd[1])
d, theta = self.calc_distance_and_angle(nearest_node, new_node)
def steer(self, from_node, to_node):
d, theta = self.calc_distance_and_angle(from_node, to_node)
if d > self.expand_dis:
new_node.x = nearest_node.x + self.expand_dis * math.cos(theta)
new_node.y = nearest_node.y + self.expand_dis * math.sin(theta)
x = from_node.x + self.expand_dis * math.cos(theta)
y = from_node.y + self.expand_dis * math.sin(theta)
else:
x = to_node.x
y = to_node.y
new_node = self.Node(x, y)
new_node.parent = from_node
return new_node
@@ -106,25 +111,25 @@ class RRT:
dy = y - self.end.y
return math.sqrt(dx ** 2 + dy ** 2)
def get_random_point(self):
def get_random_node(self):
if random.randint(0, 100) > self.goal_sample_rate:
rnd = [random.uniform(self.min_rand, self.max_rand),
random.uniform(self.min_rand, self.max_rand)]
rnd = self.Node(random.uniform(self.min_rand, self.max_rand),
random.uniform(self.min_rand, self.max_rand))
else: # goal point sampling
rnd = [self.end.x, self.end.y]
rnd = self.Node(self.end.x, self.end.y)
return rnd
def draw_graph(self, rnd=None):
plt.clf()
if rnd is not None:
plt.plot(rnd[0], rnd[1], "^k")
plt.plot(rnd.x, rnd.y, "^k")
for node in self.node_list:
if node.parent:
plt.plot([node.x, node.parent.x],
[node.y, node.parent.y],
"-g")
for (ox, oy, size) in self.obstacleList:
for (ox, oy, size) in self.obstacle_list:
plt.plot(ox, oy, "ok", ms=30 * size)
plt.plot(self.start.x, self.start.y, "xr")
@@ -134,8 +139,8 @@ class RRT:
plt.pause(0.01)
@staticmethod
def get_nearest_list_index(node_list, rnd):
dlist = [(node.x - rnd[0]) ** 2 + (node.y - rnd[1])
def get_nearest_list_index(node_list, rnd_node):
dlist = [(node.x - rnd_node.x) ** 2 + (node.y - rnd_node.y)
** 2 for node in node_list]
minind = dlist.index(min(dlist))