K-means clustering is a basic unsupervised learning technique. In this post we will implement it in several ways and progressively optimize the PyTorch version. The dependencies are NumPy, Matplotlib, PyTorch, and scikit-learn.
First we import the necessary libraries and write two helper functions: one to generate data for the clustering, and one to plot it.
import torch, random, time
import numpy as np
import matplotlib.pyplot as plt
def set_seeds(seed=0):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
def plotter(groups, centroids, outimg):
    # plot each group in its own color, with the centroid drawn larger
    for group, centroid in zip(groups, centroids):
        group = np.array(group)
        joint = np.concatenate([group, [centroid]])
        plt.scatter(joint[:, 0], joint[:, 1], s=[2] * len(group) + [50])
    plt.savefig(outimg)
    plt.clf()
def gen_data(numpts, numcentroids):
    centroids = np.random.rand(numcentroids, 2)
    # split numpts as evenly as possible across the clusters
    num_pts_in_clusters = [numpts // numcentroids] * numcentroids
    num_pts_in_clusters[-1] += numpts - sum(num_pts_in_clusters)
    assert sum(num_pts_in_clusters) == numpts
    # sample each cluster from a Gaussian around its centroid
    groups = [np.stack([np.random.normal(mn, [0.07, 0.07]) for _ in range(n)])
              for n, mn in zip(num_pts_in_clusters, centroids)]
    # pick random points as the starting centroids for k-means
    idx = np.random.choice(numpts, numcentroids, replace=False)
    pts = np.concatenate(groups)
    return pts, pts[idx], groups, centroids
Now let's call the code above to generate a sample dataset and view it.
set_seeds()
num_steps = 100
pts, start_centroids, groups, centroids = gen_data(1000, 10)
plotter(groups, centroids, 'ground_truth.png')
def time_repeat(count=1, warmup=0):
    assert warmup < count
    def timer(f):
        def helper(*args, **kwargs):
            tot_time = 0
            for i in range(count):
                t0 = time.time()
                out = f(*args, **kwargs)
                elapsed = time.time() - t0
                print(f'{i}/{count} in time_repeat {elapsed:.4f}s')
                # skip the first `warmup` runs when accumulating the total
                if i >= warmup:
                    tot_time += elapsed
            print(f'Timed function {count-warmup} times. tot time {tot_time:.4f}s. avg time {tot_time/(count-warmup):.4f}s')
            return out
        return helper
    return timer
The decorator above times a function by running it multiple times and reporting the average, excluding an initial number of warmup runs.
First, we implement k-means with a double for loop in plain Python and NumPy, which we expect to be very slow.
@time_repeat(5, 1)
def kmeans_naive(pts, centroids, steps):
    def dist(x, y):
        return sum((x - y) * (x - y))
    # run one extra iteration: on the last pass we only compute the groups
    # and do not recompute the centroids
    for step in range(steps + 1):
        groups = [[] for _ in range(len(centroids))]
        # assign each point to its closest centroid
        for pt in pts:
            centroid_dists = [dist(centroid, pt) for centroid in centroids]
            groups[np.argmin(centroid_dists)] += [pt]
        if step == steps:
            break
        # recompute each centroid as the mean of the points assigned to it
        centroids = np.array([np.mean(np.array(group), 0) for group in groups])
    return groups, centroids
naive_groups, naive_centroids = kmeans_naive(pts, start_centroids, num_steps)
plotter(naive_groups, naive_centroids, 'naive.png')
The average run time for this implementation is 2.3322s.
Now let's implement it in PyTorch.
@time_repeat(5, 1)
def kmeans_pt(pts, centroids, steps):
    pts = torch.tensor(pts)  # N x 2
    centroids = torch.tensor(centroids)  # C x 2
    for step in range(steps + 1):
        # N x 1 x 2 - 1 x C x 2 -> N x C x 2 by broadcasting
        diffs = pts.unsqueeze(1) - centroids.unsqueeze(0)
        diffs_sq = diffs * diffs
        diffs_sq_sum = torch.sum(diffs_sq, 2)  # N x C
        clusters = torch.argmin(diffs_sq_sum, 1)  # N
        groups = [pts[clusters == i] for i in range(len(centroids))]
        if step == steps:
            break
        centroids = torch.stack([torch.mean(group, 0) for group in groups])
    return [group.numpy() for group in groups], centroids.numpy()
pt_groups, pt_centroids = kmeans_pt(pts, start_centroids, num_steps)
plotter(pt_groups, pt_centroids, 'pt.png')
The average run time for this implementation is 0.0473s.
This is much faster than the naive implementation, since the per-point Python loop is replaced by a few large vectorized tensor operations.
We can try applying TorchScript JIT to the static portion of the code and see whether it gains us anything.
@torch.jit.script
def helper(pts, centroids):
    # N x 1 x 2 - 1 x C x 2 -> N x C x 2 by broadcasting
    diffs = pts.unsqueeze(1) - centroids.unsqueeze(0)
    diffs_sq = diffs * diffs
    diffs_sq_sum = torch.sum(diffs_sq, 2)  # N x C
    return torch.argmin(diffs_sq_sum, 1)  # N
@time_repeat(5, 1)
def kmeans_pt1(pts, centroids, steps):
    pts = torch.tensor(pts)  # N x 2
    centroids = torch.tensor(centroids)  # C x 2
    for step in range(steps + 1):
        clusters = helper(pts, centroids)
        groups = [pts[clusters == i] for i in range(len(centroids))]
        if step == steps:
            break
        centroids = torch.stack([torch.mean(group, 0) for group in groups])
    return [group.numpy() for group in groups], centroids.numpy()
pt1_groups, pt1_centroids = kmeans_pt1(pts, start_centroids, num_steps)
plotter(pt1_groups, pt1_centroids, 'pt1.png')
The average run time for this implementation is 0.0470s.
JIT-ing did not gain us anything, most likely because the work already consists of a few large tensor operations, leaving little Python overhead for the JIT to remove.
If we look at the current code, the computation of the groups and of the centroids from the groups looks expensive, because we drop into Python lists instead of staying in PyTorch.
Let's analyse these two lines with a small example:
groups = [pts[clusters == i] for i in range(len(centroids))]
centroids = torch.stack([torch.mean(group, 0) for group in groups])
Consider the image on the right, which illustrates the current code on a small example. groups is a Python list, so we are outside PyTorch. Also, the sizes of the individual tensors in groups will most likely differ at every step. Such dynamic shapes can be a performance bottleneck on certain hardware such as TPUs, as mentioned here.
Now consider the matrix multiplication shown on the left. Multiplying a C x N assignment matrix by the N x 2 points essentially adds up the points in each group, implementing something like this.
Finally, if we weight each row of the matrix appropriately, we get a group-by mean.
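To make this concrete, here is a minimal standalone sketch (with toy values of my own, not from the dataset above) showing how a one-hot assignment matrix and a matrix multiplication compute per-group sums and means:
# 5 points in 2-D, assigned to 3 clusters: point i belongs to cluster toy_clusters[i]
toy_pts = torch.tensor([[0., 0.], [1., 0.], [4., 4.], [4., 6.], [9., 9.]])
toy_clusters = torch.tensor([0, 0, 1, 1, 2])
# C x N one-hot matrix: row c has a 1 in column i iff point i is in cluster c
toy_mtx = torch.zeros(3, 5)
toy_mtx[toy_clusters, torch.arange(5)] = 1
# C x N times N x 2 -> C x 2: each row is the sum of the points in that group
group_sums = torch.mm(toy_mtx, toy_pts)
# L1-normalizing each row divides by the group size, turning sums into means
toy_mtx_norm = torch.nn.functional.normalize(toy_mtx, p=1, dim=1)
group_means = torch.mm(toy_mtx_norm, toy_pts)
print(group_means)  # [[0.5, 0.0], [4.0, 5.0], [9.0, 9.0]]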
@time_repeat(5, 1)
def kmeans_pt2(pts, centroids, steps):
    pts = torch.tensor(pts, dtype=torch.float32)  # N x 2
    centroids = torch.tensor(centroids, dtype=torch.float32)  # C x 2
    N = len(pts)
    range_n = torch.arange(N)
    for step in range(steps + 1):
        # N x 1 x 2 - 1 x C x 2 -> N x C x 2 by broadcasting
        diffs = pts.unsqueeze(1) - centroids.unsqueeze(0)
        diffs_sq = diffs * diffs
        diffs_sq_sum = torch.sum(diffs_sq, 2)  # N x C
        clusters = torch.argmin(diffs_sq_sum, 1)  # N
        if step == steps:
            break
        # C x N one-hot assignment matrix; use len(centroids) rather than
        # clusters.max()+1 so empty trailing clusters don't shrink the matrix
        mtx = torch.zeros(len(centroids), N)
        mtx[clusters, range_n] = 1
        # L1-normalizing the rows turns the matmul into a group-by mean
        mtx_norm = torch.nn.functional.normalize(mtx, p=1, dim=1)
        centroids = torch.mm(mtx_norm, pts)
    return [pts[clusters == i].numpy() for i in range(len(centroids))], centroids.numpy()
pt2_groups, pt2_centroids = kmeans_pt2(pts, start_centroids, num_steps)
plotter(pt2_groups, pt2_centroids, 'pt2.png')
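As a quick sanity check (my addition, not from the original timing runs), all implementations start from the same centroids, so they should converge to essentially the same result; only float64-vs-float32 rounding should differ:
# maximum absolute difference between the naive (float64) and matmul (float32)
# centroids; expect something tiny (on the order of 1e-6) if they agree
print(np.abs(naive_centroids - pt2_centroids).max())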
The average run time for this implementation is 0.0251s.
Thus we see a significant speedup with this approach.
Next, let's factor out the static sections inside the for loop and JIT them.
@torch.jit.script
def helper2(pts, centroids):
    # N x 1 x 2 - 1 x C x 2 -> N x C x 2 by broadcasting
    diffs = pts.unsqueeze(1) - centroids.unsqueeze(0)
    diffs_sq = diffs * diffs
    diffs_sq_sum = torch.sum(diffs_sq, 2)  # N x C
    return torch.argmin(diffs_sq_sum, 1)  # N
@torch.jit.script
def helper3(pts, centroids, range_n, N: int):
    # N x 1 x 2 - 1 x C x 2 -> N x C x 2 by broadcasting
    diffs = pts.unsqueeze(1) - centroids.unsqueeze(0)
    diffs_sq = diffs * diffs
    diffs_sq_sum = torch.sum(diffs_sq, 2)  # N x C
    clusters = torch.argmin(diffs_sq_sum, 1)  # N
    # C x N one-hot assignment matrix; centroids.shape[0] keeps the row count
    # fixed even if some clusters are empty, and is a plain int for TorchScript
    mtx = torch.zeros(centroids.shape[0], N)
    mtx[clusters, range_n] = 1
    mtx_norm = torch.nn.functional.normalize(mtx, p=1.0, dim=1)
    return torch.mm(mtx_norm, pts)
@time_repeat(5, 1)
def kmeans_pt3(pts, centroids, steps):
    pts = torch.tensor(pts, dtype=torch.float32)  # N x 2
    centroids = torch.tensor(centroids, dtype=torch.float32)  # C x 2
    N = len(pts)
    range_n = torch.arange(N)
    for step in range(steps):
        centroids = helper3(pts, centroids, range_n, N)
    # one final assignment pass with the converged centroids
    clusters = helper2(pts, centroids)
    return [pts[clusters == i].numpy() for i in range(len(centroids))], centroids.numpy()
pt3_groups, pt3_centroids = kmeans_pt3(pts, start_centroids, num_steps)
plotter(pt3_groups, pt3_centroids, 'pt3.png')
The average run time for this implementation is 0.0234s.
Thus JIT helps a little in this case.
Finally, let's try scikit-learn. We will see that this is very fast, as it is an optimized library for this specific task.
@time_repeat(5, 1)
def kmeans_scikit(pts, centroids, steps):
    from sklearn.cluster import KMeans
    # n_init=1 since we pass an explicit init; otherwise scikit-learn warns
    # that only one initialization is performed anyway
    kmeans = KMeans(n_clusters=len(centroids), random_state=0, init=centroids,
                    n_init=1, max_iter=steps).fit(pts)
    out = kmeans.predict(pts)
    groups = [pts[out == i] for i in range(len(centroids))]
    return groups, kmeans.cluster_centers_
scikit_groups, scikit_centroids = kmeans_scikit(pts, start_centroids, num_steps)
plotter(scikit_groups, scikit_centroids, 'scikit.png')
The average run time for this implementation is 0.0057s.
We started with a double for loop implementation, which is very slow. We then implemented k-means in PyTorch, and optimized it further by expressing the group-and-average centroid computation as a matrix multiplication. Since modern deep learning hardware usually has a matrix multiplication accelerator, and this implementation has no dynamic shapes, the optimization should pay off even more on such specialized hardware. On CPU, however, scikit-learn's k-means remains the fastest.
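For reference, the average run times measured above, side by side:
kmeans_naive   (double for loop)    2.3322s
kmeans_pt      (PyTorch)            0.0473s
kmeans_pt1     (PyTorch + JIT)      0.0470s
kmeans_pt2     (matmul centroids)   0.0251s
kmeans_pt3     (matmul + JIT)       0.0234s
kmeans_scikit  (scikit-learn)       0.0057s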