mirror of
https://github.com/AtsushiSakai/PythonRobotics.git
synced 2026-01-14 09:08:01 -05:00
first release k means simulation
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
|
||||
Object clustering with k-mean algorithm
|
||||
Object clustering with k-means algorithm
|
||||
|
||||
author: Atsushi Sakai (@Atsushi_twi)
|
||||
|
||||
@@ -25,9 +25,23 @@ class Clusters:
|
||||
self.cy = [0.0 for _ in range(nlabel)]
|
||||
|
||||
|
||||
def init_clusters(rx, ry, nc):
|
||||
def kmeans_clustering(rx, ry, nc):
|
||||
|
||||
clusters = Clusters(rx, ry, nc)
|
||||
clusters = calc_centroid(clusters)
|
||||
|
||||
MAX_LOOP = 10
|
||||
DCOST_TH = 0.1
|
||||
pcost = 100.0
|
||||
for loop in range(MAX_LOOP):
|
||||
# print("Loop:", loop)
|
||||
clusters, cost = update_clusters(clusters)
|
||||
clusters = calc_centroid(clusters)
|
||||
|
||||
dcost = abs(cost - pcost)
|
||||
if dcost < DCOST_TH:
|
||||
break
|
||||
pcost = cost
|
||||
|
||||
return clusters
|
||||
|
||||
@@ -62,44 +76,6 @@ def update_clusters(clusters):
|
||||
return clusters, cost
|
||||
|
||||
|
||||
def kmean_clustering(rx, ry, nc):
|
||||
|
||||
clusters = init_clusters(rx, ry, nc)
|
||||
clusters = calc_centroid(clusters)
|
||||
|
||||
MAX_LOOP = 10
|
||||
DCOST_TH = 1.0
|
||||
pcost = 100.0
|
||||
for loop in range(MAX_LOOP):
|
||||
print("Loop:", loop)
|
||||
clusters, cost = update_clusters(clusters)
|
||||
clusters = calc_centroid(clusters)
|
||||
|
||||
dcost = abs(cost - pcost)
|
||||
if dcost < DCOST_TH:
|
||||
break
|
||||
pcost = cost
|
||||
|
||||
return clusters
|
||||
|
||||
|
||||
def calc_raw_data():
|
||||
|
||||
rx, ry = [], []
|
||||
|
||||
cx = [0.0, 5.0]
|
||||
cy = [0.0, 5.0]
|
||||
npoints = 30
|
||||
rand_d = 3.0
|
||||
|
||||
for (icx, icy) in zip(cx, cy):
|
||||
for _ in range(npoints):
|
||||
rx.append(icx + rand_d * (random.random() - 0.5))
|
||||
ry.append(icy + rand_d * (random.random() - 0.5))
|
||||
|
||||
return rx, ry
|
||||
|
||||
|
||||
def calc_labeled_points(ic, clusters):
|
||||
|
||||
inds = np.array([i for i in range(clusters.ndata)
|
||||
@@ -113,19 +89,87 @@ def calc_labeled_points(ic, clusters):
|
||||
return x, y
|
||||
|
||||
|
||||
def calc_raw_data(cx, cy, npoints, rand_d):
|
||||
|
||||
rx, ry = [], []
|
||||
|
||||
for (icx, icy) in zip(cx, cy):
|
||||
for _ in range(npoints):
|
||||
rx.append(icx + rand_d * (random.random() - 0.5))
|
||||
ry.append(icy + rand_d * (random.random() - 0.5))
|
||||
|
||||
return rx, ry
|
||||
|
||||
|
||||
def update_positions(cx, cy):
|
||||
|
||||
DX1 = 0.4
|
||||
DY1 = 0.5
|
||||
|
||||
cx[0] += DX1
|
||||
cy[0] += DY1
|
||||
|
||||
DX2 = -0.3
|
||||
DY2 = -0.5
|
||||
|
||||
cx[1] += DX2
|
||||
cy[1] += DY2
|
||||
|
||||
return cx, cy
|
||||
|
||||
|
||||
def calc_association(cx, cy, clusters):
|
||||
|
||||
inds = []
|
||||
|
||||
for ic in range(len(cx)):
|
||||
tcx = cx[ic]
|
||||
tcy = cy[ic]
|
||||
|
||||
dx = [icx - tcx for icx in clusters.cx]
|
||||
dy = [icy - tcy for icy in clusters.cy]
|
||||
|
||||
dlist = [math.sqrt(idx**2 + idy**2) for (idx, idy) in zip(dx, dy)]
|
||||
min_id = dlist.index(min(dlist))
|
||||
inds.append(min_id)
|
||||
|
||||
return inds
|
||||
|
||||
|
||||
def main():
|
||||
print(__file__ + " start!!")
|
||||
|
||||
rx, ry = calc_raw_data()
|
||||
|
||||
cx = [0.0, 8.0]
|
||||
cy = [0.0, 8.0]
|
||||
npoints = 10
|
||||
rand_d = 3.0
|
||||
ncluster = 2
|
||||
clusters = kmean_clustering(rx, ry, ncluster)
|
||||
sim_time = 15.0
|
||||
dt = 1.0
|
||||
time = 0.0
|
||||
|
||||
for ic in range(clusters.nlabel):
|
||||
x, y = calc_labeled_points(ic, clusters)
|
||||
plt.plot(x, y, "x")
|
||||
plt.plot(clusters.cx, clusters.cy, "o")
|
||||
plt.show()
|
||||
while time <= sim_time:
|
||||
print("Time:", time)
|
||||
time += dt
|
||||
|
||||
# simulate objects
|
||||
cx, cy = update_positions(cx, cy)
|
||||
rx, ry = calc_raw_data(cx, cy, npoints, rand_d)
|
||||
|
||||
clusters = kmeans_clustering(rx, ry, ncluster)
|
||||
|
||||
# for animation
|
||||
plt.cla()
|
||||
inds = calc_association(cx, cy, clusters)
|
||||
for ic in inds:
|
||||
x, y = calc_labeled_points(ic, clusters)
|
||||
plt.plot(x, y, "x")
|
||||
plt.plot(cx, cy, "o")
|
||||
plt.xlim(-2.0, 10.0)
|
||||
plt.ylim(-2.0, 10.0)
|
||||
plt.pause(dt)
|
||||
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user