Because k-Means Kluster!
A few months ago we started our new company, Kluster, with the name being a tip of the hat to cluster analysis. The immediate question is often:
“why do you spell kluster with a k?”
My novelty answer:
“because k-means clustering!”
In honour of this, I thought I would write a short answer to the question “what is clustering?” for the uninitiated, with a focus on k-means clustering, assisted by R.
Cluster analysis is a form of unsupervised learning, which aims to group observations into clusters.
Kernel Kluster
The central part of k-means clustering is a rather simple algorithm:
1. Initiate clusters
2. Assign observations to clusters
3. Adjust cluster centroids to be the mean of assigned observations
4. Repeat (2) and (3) until it has sufficiently converged
As always with R, there are packages which can do this for you really well. In this case, the kmeans() function in the stats package is one of many to choose from.
However, in order to understand packages better, it can be fun to quickly hack one from scratch, which is what I will do here!
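For comparison, the off-the-shelf route is a one-liner. A minimal sketch on the same two petal columns used below (the nstart value is my own choice, not something the rest of the post relies on):
# Built-in k-means, for reference
fit<-kmeans(iris[,c("Petal.Length","Petal.Width")],centers=3,nstart=25)
# Cluster sizes
table(fit$cluster)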
Special K
There are many ways to choose the best k, which we won’t discuss here.
I am using the classic iris data set, so I will cheat and set k=3 for this example (because there are three species of irises in the data set).
Step 1 is to initiate the clusters. There are many ways of doing this – I will split my dataset into k random samples, and set the k cluster centroids as the mean locations of these points. Another option is to randomly select k points from the data as a starting point.
# Load the packages we lean on (dplyr for the pipes, reshape2 for melt)
library(dplyr)
library(reshape2)
# Set up the data: the two petal measurements, plus iter/kluster bookkeeping columns
data<-data.frame(rep(0,150),rep(0,150),iris$Petal.Length,iris$Petal.Width)
colnames(data)<-c("iter","kluster","Petal.Length","Petal.Width")
# Fit to three klusters
k<-3
# Randomly split the data
split_list<-split(data, sample(1:k, nrow(data), replace=T))
# Create the centroids based on the random data samples
centres<-melt(split_list,id.vars="iter") %>%
  group_by(L1) %>%
  summarise(
    Petal.Length=mean(value[variable=="Petal.Length"]),
    Petal.Width=mean(value[variable=="Petal.Width"])
  )
colnames(centres)<-c("kluster","Petal.Length","Petal.Width")
I have created my own custom function to calculate the Euclidean distance between two points:
euclidean_dist<-function(x1,y1,x2,y2)
{
  x_dist<-abs(x1-x2)
  y_dist<-abs(y1-y2)
  euc_dist<-sqrt(x_dist^2+y_dist^2)
  return(euc_dist)
}
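A quick sanity check with a 3-4-5 triangle:
euclidean_dist(0,0,3,4)
# [1] 5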
… and also a second function to assign each observation to a cluster …
assignDataPointsToNearestCentroid<-function(iter,centres,data)
{
  # Holds the distance from the current point to each of the k centres
  dists<-rep(0,k)
  # loop through each data point
  for (i in 1:nrow(data))
  {
    curr_x<-data$Petal.Length[i]
    curr_y<-data$Petal.Width[i]
    # Get the distance to each centre
    for (j in 1:k)
    {
      klust_x<-centres$Petal.Length[j]
      klust_y<-centres$Petal.Width[j]
      dists[j]<-euclidean_dist(curr_x,curr_y,klust_x,klust_y)
    }
    # Assign the point to whichever centre is closest
    closest_centre<-which.min(dists)
    data$kluster[i]<-closest_centre
  }
  data$iter<-iter
  return(data)
}
Note that looping in R is inefficient, but it can be useful for understanding the underlying logic of a method.
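For the curious, here is a vectorised sketch of the same assignment step. This hypothetical assignVectorised helper is my own illustration and isn’t used in the rest of the post:
assignVectorised<-function(iter,centres,data)
{
  # Distance from every observation to every centroid:
  # one row per observation, one column per centroid
  dist_mat<-sapply(1:nrow(centres),function(j)
    euclidean_dist(data$Petal.Length,data$Petal.Width,
                   centres$Petal.Length[j],centres$Petal.Width[j]))
  # Pick the nearest centroid for each row
  data$kluster<-apply(dist_mat,1,which.min)
  data$iter<-iter
  return(data)
}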
I now use these custom functions to assign each observation to a cluster:
data<-assignDataPointsToNearestCentroid(0,centres,data)
Cool!
We now have each observation assigned to its nearest randomly-initiated centroid.
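A quick peek at the initial assignment counts (your numbers will differ, since the split is random):
table(data$kluster)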
Next, the exciting part. We’ll loop through iterations of updating the centroids and re-assigning the observations until we’re happy with the result. I’ve written a custom function to update the centroids (note it uses dplyr):
updateCentroids<-function(centres,data)
{
  centres<-data %>%
    group_by(kluster) %>%
    summarise(
      Petal.Length=mean(Petal.Length),
      Petal.Width=mean(Petal.Width)
    )
  return(centres)
}
I know that this will converge quickly, so I’ll cheat and set maximum iterations to 8 for the purpose of a nice graph you’re about to see. In reality you could have a high maximum, exiting earlier if the observations stop changing clusters.
Let’s iterate:
for(i in 1:8)
{
  centres<-updateCentroids(centres,data[which(data$iter==i-1),])
  data<-rbind(data,assignDataPointsToNearestCentroid(i,centres,data[which(data$iter==i-1),]))
}
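If you did want that early exit, a minimal sketch might compare assignments between iterations. The all() convergence check here is my own addition, not part of the original loop:
for(i in 1:100)
{
  prev<-data[which(data$iter==i-1),]
  centres<-updateCentroids(centres,prev)
  new_data<-assignDataPointsToNearestCentroid(i,centres,prev)
  data<-rbind(data,new_data)
  # Stop once no observation has changed cluster
  if(all(new_data$kluster==prev$kluster)) break
}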
Now, let’s have a look at how the clusters converge over each iteration:
We can see that the clusters stabilise at about the fifth iteration, so it appears that our custom clustering algorithm (roughly) works!
To check, let’s compare against the actual species from the iris data set to get a sense of our accuracy:
Awesome. As you can see, the setosas were clustered correctly, and there is a small amount of confusion between virginica and versicolor, but on the whole, pretty cool!
Let’s quickly check the accuracy with a confusion matrix (using R’s table() function) before we finish:
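In code, that looks something like the below. The iteration-8 assignments line up row-for-row with iris; note that k-means cluster labels are arbitrary, so the diagonal is only meaningful once the cluster numbers happen to align with the species order (as they do in this run; otherwise you would relabel first):
confusion<-table(iris$Species,data$kluster[data$iter==8])
confusion
# Accuracy: (sum of the diagonal)/(sum of total)
sum(diag(confusion))/sum(confusion)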
Not bad: all the setosas were classified correctly, and only six of the remaining observations were misclassified.
The accuracy, (sum of the diagonal)/(sum of total), comes out at 96%, which I’m happy with!
There is plenty to improve on, and of course you shouldn’t use this for actual clustering. However, I find exercises like this are a good way to properly understand the logic underlying the methods we can leverage through R’s great packages.
The full code is below:
#--------------------------------------------------
# Cheeky k-Means Klustering Kode
#--------------------------------------------------
library(ggplot2)
library(dplyr)
library(reshape2)
#--------------------------------------------------
# Declare custom function we'll need
#--------------------------------------------------
# Use the Euclidean distance.
euclidean_dist<-function(x1,y1,x2,y2)
{
  x_dist<-abs(x1-x2)
  y_dist<-abs(y1-y2)
  euc_dist<-sqrt(x_dist^2+y_dist^2)
  return(euc_dist)
}
assignDataPointsToNearestCentroid<-function(iter,centres,data)
{
  # Holds the distance from the current point to each of the k centres
  dists<-rep(0,k)
  # loop through each data point
  for (i in 1:nrow(data))
  {
    curr_x<-data$Petal.Length[i]
    curr_y<-data$Petal.Width[i]
    # Get the distance to each centre
    for (j in 1:k)
    {
      klust_x<-centres$Petal.Length[j]
      klust_y<-centres$Petal.Width[j]
      dists[j]<-euclidean_dist(curr_x,curr_y,klust_x,klust_y)
    }
    # Assign the point to whichever centre is closest
    closest_centre<-which.min(dists)
    data$kluster[i]<-closest_centre
  }
  data$iter<-iter
  return(data)
}
updateCentroids<-function(centres,data)
{
  centres<-data %>%
    group_by(kluster) %>%
    summarise(
      Petal.Length=mean(Petal.Length),
      Petal.Width=mean(Petal.Width)
    )
  return(centres)
}
#--------------------------------------------------
# Set up data
#--------------------------------------------------
data<-data.frame(rep(0,150),rep(0,150),iris$Petal.Length,iris$Petal.Width)
colnames(data)<-c("iter","kluster","Petal.Length","Petal.Width")
#--------------------------------------------------
# INITIAL CLUSTER SETUP
#
# Set up initial cluster data points by splitting
# data into three random groups, and getting the
# centre
#--------------------------------------------------
# Fit to three klusters
k<-3
# Randomly split the data
split_list<-split(data, sample(1:k, nrow(data), replace=T))
# Create the centroids based on the random data samples
centres<-melt(split_list,id.vars="iter") %>%
  group_by(L1) %>%
  summarise(
    Petal.Length=mean(value[variable=="Petal.Length"]),
    Petal.Width=mean(value[variable=="Petal.Width"])
  )
colnames(centres)<-c("kluster","Petal.Length","Petal.Width")
# Assign data to its closest centre
data<-assignDataPointsToNearestCentroid(0,centres,data)
#--------------------------------------------------
# GET CLUSTERING!
# Iterate through by:
# (1) Update the centroids to be the middle of
# the latest clusters.
# (2) Update the data points to be a member of
# cluster whose centroid is closest.
#--------------------------------------------------
# In reality, you would iterate many times or until
# the observations remain in the same clusters over
# subsequent iterations.
#
# For the sake of the ggplot which is coming,
# we will loop 8 times.
for(i in 1:8)
{
  centres<-updateCentroids(centres,data[which(data$iter==i-1),])
  data<-rbind(data,assignDataPointsToNearestCentroid(i,centres,data[which(data$iter==i-1),]))
}
# Plot charts to see how the iterations evolve
evolving_plots<-
  ggplot(data,aes(Petal.Length,Petal.Width))+
  geom_point(colour="black",size=4)+
  geom_point(aes(colour=factor(kluster)),size=3)+
  theme(panel.background = element_rect(fill = "white"))+
  facet_wrap(~iter)
# Plot the final chart
fitted_plot<-
  ggplot(data[which(data$iter==8),],aes(Petal.Length,Petal.Width))+
  geom_point(colour="black",size=4)+
  geom_point(aes(colour=factor(kluster)),size=3)+
  theme(panel.background = element_rect(fill = "white"))
# Plot what the actual species are
final_plot<-
  ggplot(iris,aes(Petal.Length,Petal.Width))+
  geom_point(colour="black",size=4)+
  geom_point(aes(colour=factor(Species)),size=3)+
  theme(panel.background = element_rect(fill = "white"))
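One last note: the three plot objects above are only assigned, so if you run this as a script (rather than pasting it into the console) you’ll need to print them explicitly:
print(evolving_plots)
print(fitted_plot)
print(final_plot)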