Machine Learning Explained: Kmeans


Kmeans is one of the most popular and simplest algorithms for discovering underlying structure in your data. The goal of kmeans is simple: split your data into k groups, each represented by its mean. The mean of each group is assumed to be a good summary of the observations in that cluster.

Kmeans Algorithm

We assume that we want to split the data into k groups, so we need to find and assign k centers. How do we define and find these centers?

They are the solution to the optimization problem:

\min \sum_{i=1}^{N} \sum_{j=1}^{K} O_i^j \|x_i - \mu_j\|^2

where O_i^j = 1 if observation i is assigned to center j, and O_i^j = 0 otherwise.

Basically, this equation means that we are looking for the k centers that minimize the distance between each point and the center of its cluster. This is an optimization problem, but since the function we want to minimize is not convex and some of the variables are binary, it cannot be solved directly with classic methods such as gradient descent.
The usual way to solve it is the following:

  1. Randomly initialize the centers by selecting k observations
  2. While some convergence criterion is not met:
    1. Assign each observation to its closest center.
    2. Update each center: the new centers are the means of the observations of each group.
    3. Update the convergence criterion.
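
To make the objective concrete, here is a minimal R sketch computing the within-cluster sum of squares for a given assignment; this is exactly the quantity the loop above drives down at each iteration. The `wss` helper and its arguments are our own, not part of the original post.

##Within-cluster sum of squares for a given assignment
##data: one observation per row, cluster: assignment vector,
##centroids: one center per row
wss=function(data,cluster,centroids)
{
  total=0
  for (j in 1:nrow(centroids))
  {
##Squared Euclidean distance of each point of cluster j to its center
    diffs=sweep(as.matrix(data[cluster==j,]),2,as.numeric(centroids[j,]))
    total=total+sum(diffs^2)
  }
  total
}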

K-means from scratch with R

Now that we have the algorithm in pseudocode, let’s implement kmeans from scratch in R. First, we’ll create some toy data based on five 2D Gaussian distributions.

require(MASS)
require(ggplot2)
set.seed(1234)
set1=mvrnorm(n = 300, c(-4,10), matrix(c(1.5,1,1,1.5),2))
set2=mvrnorm(n = 300, c(5,7), matrix(c(1,2,2,6),2))
set3=mvrnorm(n = 300, c(-1,1), matrix(c(4,0,0,4),2))
set4=mvrnorm(n = 300, c(10,-10), matrix(c(4,0,0,4),2))
set5=mvrnorm(n = 300, c(3,-3), matrix(c(4,0,0,4),2))
DF=data.frame(rbind(set1,set2,set3,set4,set5),cluster=as.factor(c(rep(1:5,each=300))))
ggplot(DF,aes(x=X1,y=X2,color=cluster))+geom_point()

On this dataset, kmeans will work well since each distribution has a compact, circular shape. Here is what the data look like:
[Figure: scatter plot of the five simulated Gaussian clusters, coloured by generating distribution]

Now that we have a dataset, let’s implement kmeans.

Initialisation of the centroids

The initialisation of the centroids is crucial and changes how the algorithm behaves. Here, we will simply take K random points from the data.

#Centroids initialisation: take K observations at random from the data
  centroids=data[sample.int(nrow(data),K),]
##Stopping criterion initialisation
  current_stop_crit=10e10
##Vector where the assigned center of each point will be saved
  cluster=rep(0,nrow(data))
##Has the algorithm converged?
  converged=F
  it=1

Assigning points to their clusters

At each iteration, every point is assigned to its closest cluster. To do so, the Euclidean distance between each point and each center is computed; the lowest distance, and the center for which it is reached, are saved.

###Iterating over observations
    for (i in 1:nrow(data))
    {
##Setting a high minimum distance
      min_dist=10e10
##Iterating over centroids
      for (centroid in 1:nrow(centroids))
      {
##Computing the squared Euclidean (L2) distance
        distance_to_centroid=sum((centroids[centroid,]-data[i,])^2)
##This centroid is, so far, the closest one to the point
        if (distance_to_centroid<=min_dist)
        {
##The point is assigned to this centroid/cluster
          cluster[i]=centroid
          min_dist=distance_to_centroid
        }
      }
    }
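
As an aside, this double loop is easy to read but slow in R. A vectorised version of the same assignment step might look like the sketch below; it uses the same variable names, but the `dist_to_centroids` intermediate is our own.

##Vectorised assignment: n x K matrix of squared distances to each centroid
dist_to_centroids=sapply(1:nrow(centroids),function(centroid)
  rowSums(sweep(as.matrix(data),2,as.numeric(centroids[centroid,]))^2))
##Each point gets the index of its closest centroid
cluster=apply(dist_to_centroids,1,which.min)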

Centroids update

Once each point has been assigned to its closest centroid, the coordinates of each centroid are updated. The new coordinates are the means of the observations belonging to the cluster.

##For each centroid
    for (i in 1:nrow(centroids))
    {
##The new coordinates are the means of the points in the cluster
      centroids[i,]=apply(data[cluster==i,],2,mean)
    }
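
One corner case worth guarding against: if a centroid loses all of its points, `data[cluster==i,]` is empty and the mean is undefined. A simple fix (our addition, not part of the original code) is to re-seed empty centroids from a random observation:

##For each centroid, guarding against empty clusters
    for (i in 1:nrow(centroids))
    {
      if (sum(cluster==i)>0)
      {
##The new coordinates are the means of the points in the cluster
        centroids[i,]=apply(data[cluster==i,],2,mean)
      } else {
##Empty cluster: re-seed this centroid from a random observation
        centroids[i,]=data[sample.int(nrow(data),1),]
      }
    }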

Stopping criterion

We do not want the algorithm to run indefinitely, so we need a stopping criterion to halt the algorithm once it is close enough to a minimum. The criterion is simple: when the centroids (almost) stop moving, the algorithm should stop.

  while(current_stop_crit>=stop_crit & converged==F)
  {
    it=it+1
    if (current_stop_crit<=stop_crit)
    {
      converged=T
    }
    old_centroids=centroids
    ##Run the assignment and centroid update steps from the previous sections here
    ##Recompute the stopping criterion: mean squared displacement of the centroids
    ##as.matrix ensures mean() gets a numeric object when the data is a data.frame
    current_stop_crit=mean(as.matrix((old_centroids-centroids)^2))
  }

Complete function

kmeans=function(data,K=4,stop_crit=10e-5)
{
  #Initialisation of clusters
  centroids=data[sample.int(nrow(data),K),]
  current_stop_crit=1000
  cluster=rep(0,nrow(data))
  converged=F
  it=1
  while(current_stop_crit>=stop_crit & converged==F)
  {
    it=it+1
    if (current_stop_crit<=stop_crit)
    {
      converged=T
    }
    old_centroids=centroids
    ##Assigning each point to a centroid
    for (i in 1:nrow(data))
    {
      min_dist=10e10
      for (centroid in 1:nrow(centroids))
      {
        distance_to_centroid=sum((centroids[centroid,]-data[i,])^2)
        if (distance_to_centroid<=min_dist)
        {
          cluster[i]=centroid
          min_dist=distance_to_centroid
        }
      }
    }
    ##Updating each centroid: its new coordinates are the means of its points
    for (i in 1:nrow(centroids))
    {
      centroids[i,]=apply(data[cluster==i,],2,mean)
    }
    ##as.matrix ensures mean() gets a numeric object when the data is a data.frame
    current_stop_crit=mean(as.matrix((old_centroids-centroids)^2))
  }
  return(list(data=data.frame(data,cluster),centroids=centroids))
}

You can easily run the code to see your clusters:

res=kmeans(DF[1:2],K=5)
res$centroids$cluster=1:5
res$data$isCentroid=F
res$centroids$isCentroid=T
data_plot=rbind(res$centroids,res$data)
ggplot(data_plot,aes(x=X1,y=X2,color=as.factor(cluster),size=isCentroid,alpha=isCentroid))+geom_point()

Exploring kmeans results

Now let’s try the algorithm on two different datasets. First, on the five Gaussian distributions:
[Figure: evolution of the centroid positions during the kmeans iterations]
The centroids move and split the data into clusters which are very close to the original ones. Kmeans is doing a great job here.
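
You can check how close with a simple cross-tabulation of the generating labels against the recovered ones (assuming the `res` object from the run above). The kmeans labels are arbitrary, so a good result shows one dominant kmeans label per true cluster.

##Rows: true generating cluster, columns: kmeans assignment
table(true=DF$cluster,assigned=res$data$cluster)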
Now, instead of nice Gaussian distributions, we will build three rings, one inside another.

##Building three sets in polar coordinates
set1=data.frame(r=runif(300,0.1,0.5),theta=runif(300,0,2*pi),set='1')
set2=data.frame(r=runif(300,1,1.5),theta=runif(300,0,2*pi),set='2')
set3=data.frame(r=runif(300,3,5),theta=runif(300,0,2*pi),set='3')
##Transformation into rings (Cartesian coordinates)
data_2=rbind(set1,set2,set3)
data_2$x=data_2$r*cos(data_2$theta)
data_2$y=data_2$r*sin(data_2$theta)

Kmeans performs very poorly on these new data. The Euclidean distance is simply not suited to this kind of problem: the clusters are rings rather than compact, blob-shaped groups.
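
You can see this for yourself with the function from the previous section (a minimal check; `res_rings` is our own name, and the recovered labels will cut across the rings):

res_rings=kmeans(data_2[,c('x','y')],K=3)
ggplot(res_rings$data,aes(x=x,y=y,color=as.factor(cluster)))+geom_point()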

So before using kmeans, you should check that your clusters have an appropriate shape; if not, you can transform the data or change the distance used in the kmeans.
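
For the rings above, for instance, a single derived feature, the distance to the origin, turns the three rings into three well-separated groups that kmeans handles easily. A minimal sketch, with our own names: the constant dummy column only keeps the data two-dimensional, since the toy function above assumes matrix-like data, and an unlucky initialisation can still merge two rings.

##The radius alone separates the three rings
data_radius=data.frame(r=sqrt(data_2$x^2+data_2$y^2),dummy=0)
res_radius=kmeans(data_radius,K=3)
##One dominant kmeans label per ring indicates a good split
table(ring=data_2$set,assigned=res_radius$data$cluster)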

The importance of a good initialisation

The kmeans algorithm only finds a local minimum, which is often not the global optimum. Hence, different initialisations can lead to very different results.
[Figure: kmeans results for different random initialisations]
We ran the kmeans algorithm from more than 60 different starting positions. As you can see, the algorithm sometimes ends up with poor centroids because of an unlucky initialisation. The solution is simply to run kmeans several times and keep the best set of centroids. The quality of the initialisation can also be improved with k-means++, which selects starting points that are less likely to perform poorly.
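
"Run it several times and keep the best" is straightforward to code: score each run by the total within-cluster sum of squares and keep the minimum. A minimal sketch, reusing the `wss` helper defined earlier (`best_kmeans` and `n_starts` are our own names):

##Run kmeans n_starts times and keep the run with the lowest objective
best_kmeans=function(data,K,n_starts=20)
{
  best=NULL
  best_score=Inf
  for (s in 1:n_starts)
  {
    res=kmeans(data,K=K)
    score=wss(data,res$data$cluster,res$centroids)
    if (score<best_score)
    {
      best=res
      best_score=score
    }
  }
  best
}
res_best=best_kmeans(DF[1:2],K=5)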

Want to learn more about Machine Learning? Here is a selection of Machine Learning Explained posts:
Dimensionality reduction
Supervised vs unsupervised vs reinforcement learning
Regularization in machine learning
