K-means clustering as image compression

We learned about k-means clustering in a class I was taking. The teacher mentioned that it could be used to determine the most important colors in an image, so I decided to try implementing that in Python.

from pylab import imread
import numpy as np
import matplotlib.pyplot as plt

def kmeans(X, k, max_iters=20, rng=None):
    """Cluster the rows of *X* into at most *k* groups with Lloyd's k-means.

    Parameters
    ----------
    X : array_like, shape (n_samples, n_features)
        Data points, one per row. Input is converted to float64 first, so
        integer data such as ``uint8`` pixels cannot wrap around when
        differences are squared.
    k : int
        Requested number of clusters. Clamped to ``n_samples`` because the
        initial means are ``k`` distinct rows of *X*.
    max_iters : int, optional
        Maximum number of assign/update rounds. At least one round is always
        run so that labels exist. Iteration stops early once the assignment
        no longer changes.
    rng : numpy.random.Generator, optional
        Source of randomness for choosing the initial means. Defaults to the
        global ``numpy.random`` state.

    Returns
    -------
    labels : ndarray of int, shape (n_samples,)
        For every row of *X*, the index into *means* of its cluster.
    means : list of ndarray, each of shape (n_features,)
        Cluster centres. Clusters that receive no points are dropped, so
        ``len(means)`` may be smaller than *k*; *labels* always indexes
        this list.

    Raises
    ------
    ValueError
        If *X* is not 2-D, has no rows, or *k* < 1.
    """
    X = np.asarray(X, dtype=np.float64)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D, got shape {X.shape}")
    n_samples = X.shape[0]
    if n_samples == 0:
        raise ValueError("X must contain at least one row")
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    k = min(k, n_samples)

    choice = np.random.choice if rng is None else rng.choice
    # Sample without replacement: duplicate starting means would leave a
    # cluster empty right from the first iteration.
    means = X[choice(n_samples, size=k, replace=False)]

    labels = None
    for _ in range(max(1, max_iters)):
        # Broadcast (n, 1, f) - (1, k, f) -> (n, k, f) instead of tiling copies.
        diff = X[:, np.newaxis, :] - means[np.newaxis, :, :]
        nearest = np.argmin(np.sum(diff * diff, axis=2), axis=1)
        # Renumber the surviving clusters 0..m-1 so that labels always index
        # the means computed below, even if some cluster got no points.
        _, new_labels = np.unique(nearest, return_inverse=True)
        new_labels = new_labels.ravel()
        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignment is stable; more rounds would change nothing
        labels = new_labels
        means = np.array(
            [X[labels == c].mean(axis=0) for c in range(labels.max() + 1)]
        )

    return labels, list(means)

# Reduce an image to a handful of representative colours with k-means and
# show the original next to the result.
IMAGE_PATH = 'img3_low.jpg'
N_COLORS = 7

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

img = np.asarray(plt.imread(IMAGE_PATH))
# To experiment quickly on a small patch, crop first, e.g.
# img = img[100:160, 100:160]

# Grayscale images have no channel axis; show them with a gray colormap
# (imshow ignores cmap for RGB/RGBA data).
cmap = "gray" if img.ndim == 2 else None

ax1.imshow(img, interpolation="nearest", cmap=cmap)
ax1.set_title("original")
ax1.axis("off")

# One row per pixel, one column per channel. Using the actual channel count
# instead of a hard-coded 3 also handles RGBA and grayscale images.
n_channels = img.shape[2] if img.ndim == 3 else 1
pixels = img.reshape(-1, n_channels)

# do k means clustering
labels, means = kmeans(pixels, N_COLORS, max_iters=10)

# Build the palette in the image's own dtype. This also covers float images
# (e.g. PNG read as floats in [0, 1]), which uint8 would crush to 0. For
# integer images, round the float means instead of truncating them.
palette = np.array(means)
if np.issubdtype(img.dtype, np.integer):
    palette = np.rint(palette)
palette = palette.astype(img.dtype)

img_comp = palette[labels].reshape(img.shape)
ax2.imshow(img_comp, interpolation="nearest", cmap=cmap)
ax2.set_title(f"{len(palette)} colours")
ax2.axis("off")

plt.show()

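Looking at the results, I suspect GIF relies on a similar idea: a GIF image can store at most 256 colors in its palette, so an image must first be reduced to a small set of representative colors (color quantization). The file compression itself, however, is done separately with LZW.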

Examples: example 1, example 2, example 3.

Photos by Oliver Sjöström, Eberhard Grossgasteiger and Guilherme Rossi from Pexels.