Implementing $k$-Means Clustering


In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
In [2]:
# Load the iris bunch once and take the feature matrix and the integer class labels from it
# (previously load_iris() was called twice, loading the dataset two times).
iris = load_iris()
X = iris.data
y = iris.target

print(X.shape)
print(np.unique(y))

# Scatter of feature columns 2 and 3, coloured by the true class label; axis names come
# from the dataset itself so the figure stands alone.
fig, ax = plt.subplots(dpi=150)
ax.scatter(X[:, 2], X[:, 3], marker='x', c=y)
ax.set_xlabel(iris.feature_names[2])
ax.set_ylabel(iris.feature_names[3])
ax.set_title('Iris data coloured by true class label')
plt.show()
(150, 4)
[0 1 2]
In [3]:
# The same two feature columns with every point drawn in black: the class labels are
# hidden, which is the situation a clustering algorithm starts from.
fig, ax = plt.subplots(dpi=120)
ax.scatter(X[:, 2], X[:, 3], marker='x', c='k')
ax.set_xlabel('petal length (cm)')
ax.set_ylabel('petal width (cm)')
ax.set_title('Iris data without labels')
plt.show()


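Implementing individual steps of $k$-Means Clustering

The cell below performs a single pass of the algorithm: (1) pick $k$ random data points as initial centers, (2a) compute the squared Euclidean distance from every point to every center, (2b) assign each point to its nearest center, and (2c) move each center to the mean of the points assigned to it.
<br>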

In [4]:
from scipy.spatial.distance import cdist
from sklearn.metrics import pairwise_distances

# One pass of k-means on the data matrix X, step by step.
k = 3

#1. Randomly initialize k cluster centers: pick k distinct rows of X.
idx = np.random.permutation(X.shape[0])[0:k]
# astype(float) returns a copy, so updating `centers` below never writes into X,
# and integer-valued data would not be truncated by the mean update.
centers = X[idx].astype(float)

#2. repeat until convergence (only a single pass is shown here):
#    2(a). Calculate all squared distances between data points and cluster centers (n x k).
#    Explicit-loop equivalent, kept for reference:
#        dist = np.zeros((X.shape[0], k))
#        for i in range(X.shape[0]):
#            for j in range(k):
#                dist[i, j] = (X[i] - centers[j]) @ (X[i] - centers[j])
#    cdist(X, centers, metric='sqeuclidean') produces the same matrix.
dist = pairwise_distances(X, centers, metric='sqeuclidean')

#    2(b). Update cluster memberships: one integer label per data point, shape (n,).
#    Explicit-loop equivalent, kept for reference:
#        mem = np.zeros(X.shape[0], dtype=int)
#        for i in range(X.shape[0]):
#            mem[i] = np.argmin(dist[i, :])
mem = np.argmin(dist, axis=1)

#    2(c). Update cluster centers: mean of the data points assigned to each cluster.
#    A cluster that received no points (possible when two initial centers coincide,
#    since argmin breaks ties toward the lower index) keeps its previous center
#    instead of becoming NaN from np.mean over an empty selection.
for j in range(k):
    members = X[mem == j]
    if len(members) > 0:
        centers[j] = members.mean(axis=0)
In [5]:
# Show how a boolean mask picks out the members of cluster 1, and how their
# column-wise mean becomes that cluster's new center.
in_cluster_1 = mem == 1
for shown in (mem, in_cluster_1, X[in_cluster_1], X[in_cluster_1].mean(axis=0)):
    print(shown)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 1 2 1 2 1 2 1 1 1 1 1 1 2 1 1 1 1 2 1 2 1
 1 2 2 2 1 1 1 1 1 2 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2
 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
[False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False  True False  True False  True False  True
  True  True  True  True  True False  True  True  True  True False  True
 False  True  True False False False  True  True  True  True  True False
  True  True False  True  True  True  True  True  True  True  True  True
  True  True  True  True False False False False False False  True False
 False False False False False  True False False False False False False
 False  True False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False]
[[5.5 2.3 4.  1.3]
 [5.7 2.8 4.5 1.3]
 [4.9 2.4 3.3 1. ]
 [5.2 2.7 3.9 1.4]
 [5.  2.  3.5 1. ]
 [5.9 3.  4.2 1.5]
 [6.  2.2 4.  1. ]
 [6.1 2.9 4.7 1.4]
 [5.6 2.9 3.6 1.3]
 [5.6 3.  4.5 1.5]
 [5.8 2.7 4.1 1. ]
 [6.2 2.2 4.5 1.5]
 [5.6 2.5 3.9 1.1]
 [6.1 2.8 4.  1.3]
 [6.1 2.8 4.7 1.2]
 [6.4 2.9 4.3 1.3]
 [6.  2.9 4.5 1.5]
 [5.7 2.6 3.5 1. ]
 [5.5 2.4 3.8 1.1]
 [5.5 2.4 3.7 1. ]
 [5.8 2.7 3.9 1.2]
 [5.4 3.  4.5 1.5]
 [6.  3.4 4.5 1.6]
 [6.3 2.3 4.4 1.3]
 [5.6 3.  4.1 1.3]
 [5.5 2.5 4.  1.3]
 [5.5 2.6 4.4 1.2]
 [6.1 3.  4.6 1.4]
 [5.8 2.6 4.  1.2]
 [5.  2.3 3.3 1. ]
 [5.6 2.7 4.2 1.3]
 [5.7 3.  4.2 1.2]
 [5.7 2.9 4.2 1.3]
 [6.2 2.9 4.3 1.3]
 [5.1 2.5 3.  1.1]
 [5.7 2.8 4.1 1.3]
 [4.9 2.5 4.5 1.7]
 [5.7 2.5 5.  2. ]
 [5.6 2.8 4.9 2. ]]
[5.68205128 2.67692308 4.13589744 1.30512821]


Putting it all together - the alternating-optimization loop for $k$-Means Clustering


In [6]:
from scipy.spatial.distance import cdist
from sklearn.metrics import pairwise_distances

def kmeans(X, k=3, max_iter=100, tol=1e-9, random_state=None):
    """Cluster the rows of ``X`` into ``k`` groups with Lloyd's k-means algorithm.

    Alternates between assigning every row to its nearest center (squared
    Euclidean distance) and moving every center to the mean of its rows.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Data points, one per row.
    k : int, default 3
        Number of clusters; must satisfy ``1 <= k <= n_samples``.
    max_iter : int, default 100
        Maximum number of assign/update iterations.
    tol : float, default 1e-9
        Stop once the norm of the change of the center matrix between two
        successive iterations drops below this value.
    random_state : None, int or numpy.random.Generator, default None
        Controls the random choice of the initial centers. ``None`` draws from
        the global ``np.random`` state (so ``np.random.seed`` still applies);
        anything else is passed to ``np.random.default_rng``.

    Returns
    -------
    mem : ndarray of int, shape (n_samples,)
        Index of the cluster each row is assigned to.
    centers : ndarray of float, shape (k, n_features)
        Final cluster centers.
    init_centers : ndarray of float, shape (k, n_features)
        Copy of the randomly chosen initial centers.

    Raises
    ------
    ValueError
        If ``k`` is not between 1 and the number of rows of ``X``.
    """
    X = np.asarray(X)
    n = X.shape[0]
    if not 1 <= k <= n:
        # Otherwise fewer than k centers exist and centers[j] fails later.
        raise ValueError(f'k must be between 1 and {n}, got {k}')

    #1. Randomly initialize k cluster centers: k distinct rows of X.
    if random_state is None:
        perm = np.random.permutation(n)
    else:
        perm = np.random.default_rng(random_state).permutation(n)
    # Float copy: mean updates are not truncated for integer X, and X is never mutated.
    centers = X[perm[:k]].astype(float)
    init_centers = centers.copy()

    #2. repeat until convergence:
    for v_iter in range(max_iter):
        #    2(a). Squared Euclidean distances between data points and centers (n x k).
        dist = cdist(X, centers, metric='sqeuclidean')

        #    2(b). Update cluster memberships: one integer label per row.
        mem = np.argmin(dist, axis=1)

        #    2(c). Update cluster centers: mean of the rows assigned to each cluster.
        prev_centers = centers.copy()
        for j in range(k):
            members = X[mem == j]
            # An empty cluster keeps its previous center; np.mean of an empty
            # selection would be NaN and poison every later distance.
            if len(members) > 0:
                centers[j] = members.mean(axis=0)

        # Termination criterion: the centers barely moved between successive iterations.
        if np.linalg.norm(centers - prev_centers) < tol:
            print('break at:', v_iter)
            break
    else:
        # Iteration budget exhausted (or max_iter == 0): re-assign so that the
        # returned labels are consistent with the returned centers.
        mem = np.argmin(cdist(X, centers, metric='sqeuclidean'), axis=1)
    return mem, centers, init_centers
In [7]:
# Cluster all feature columns of X, then draw the result in the plane of columns 2 and 3:
# crosses are data points coloured by assigned cluster, yellow circles are the random
# initial centers, red circles are the final centers.
mem, centers, init_centers = kmeans(X, k=3)
# k-means objective (within-cluster sum of squared distances) for the returned solution,
# so different runs can be compared.
inertia = ((X - centers[mem]) ** 2).sum()

fig, ax = plt.subplots(dpi=150)
ax.scatter(X[:, 2], X[:, 3], marker='x', c=mem)
ax.scatter(init_centers[:, 2], init_centers[:, 3], marker='o', c='y', label='initial centers')
ax.scatter(centers[:, 2], centers[:, 3], marker='o', c='r', label='final centers')
ax.set_xlabel('petal length (cm)')
ax.set_ylabel('petal width (cm)')
ax.set_title(f'k-means, k=3 (within-cluster SSE = {inertia:.2f})')
ax.legend()
plt.show()
break at: 3


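Poor Initialization leading to a non-ideal local optimum

Because the initial centers are chosen at random, running the same `kmeans` call again can converge to a different, worse local optimum; the cell below may need to be re-run a few times before such a case appears.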


In [8]:
# Same experiment from a fresh random initialization. Whether this run lands in a worse
# local optimum depends on the random draw, so it may need to be re-run; a larger
# within-cluster SSE than the previous run is the sign of a poorer solution.
mem, centers, init_centers = kmeans(X, k=3)
inertia = ((X - centers[mem]) ** 2).sum()

fig, ax = plt.subplots(dpi=150)
ax.scatter(X[:, 2], X[:, 3], marker='x', c=mem)
ax.scatter(init_centers[:, 2], init_centers[:, 3], marker='o', c='y', label='initial centers')
ax.scatter(centers[:, 2], centers[:, 3], marker='o', c='r', label='final centers')
ax.set_xlabel('petal length (cm)')
ax.set_ylabel('petal width (cm)')
ax.set_title(f'k-means from another random init (within-cluster SSE = {inertia:.2f})')
ax.legend()
plt.show()
break at: 11