Implementing the Gradient Descent Algorithm in R

[This article was first published on Environmental Science and Data Analytics, and kindly contributed to R-bloggers]. (You can report issue about the content on this page here)
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

A Brief Introduction

Linear regression is a classic supervised statistical technique for predictive modelling which is based on the linear hypothesis:

y = mx + c

where y is the response or outcome variable, m is the gradient of the linear trend-line, x is the predictor variable and c is the intercept. The intercept is the point on the y-axis where the value of the predictor x is zero.

In order to apply the linear hypothesis to a dataset with the end aim of modelling the situation under investigation, there needs to be a linear relationship between the variables in question. A simple scatterplot is an excellent visual tool to assess linearity between two variables. Below is an example of a linear relationship between miles per gallon (mpg) and engine displacement volume (disp) of automobiles which could be modelled using linear regression. Note that there are various methods of transforming non-linear data to make them appear more linear such as log and square root transformations but we won’t discuss those here.

attach(mtcars)
plot(disp, mpg, col = "blue", pch = 20)

rplot

Now, in order to fit a good model, appropriate values for the intercept and slope must be found. R has a nice function, lm(), which creates a linear model from which we can extract the most appropriate intercept and slope (coefficients).

model <- lm(mpg ~ disp, data = mtcars)
coef(model)

(Intercept)        disp 
29.59985476 -0.04121512

We see that the intercept is set at 29.59985476 on the y-axis and that the gradient is -0.04121512. The negative gradient tells us that there is an inverse relationship between mpg and displacement with one unit increase in displacement resulting in a 0.04 unit decrease in mpg.

How are these intercept and gradient values calculated one may ask? In finding the appropriate values, the goal is to reduce the value of a statistic known as the mean squared error (MSE):

MSE = Σ(y – y_preds)² / n

 

where y represents the observed value of the response variable, y_preds represents the predicted y value from the linear model after plugging in values for the intercept and slope, and n is the number of observations in the dataset. Each set of xy data points are iterated over to find the squared error, all squared errors are summed and the sum is divided by n to get the MSE.

Using the linear model we created earlier, we can obtain y predictions and plot them on the scatterplot as a regression line. We use the predict() function in base R and plot the predicted values using abline().

y_preds <- predict(model)
abline(model)

rplot01

Next, we can calculate the MSE by summing the squared differences between observed y values and our predicted y values then dividing by the number of observations n. This gives a MSE of 9.911209 for this linear model.

errors <- unname((mpg - y_preds) ^ 2)
sum(errors) / length(mpg)

 

The Gradient Descent Algorithm

I now want to introduce the Gradient Descent Algorithm which can be used to find the optimal intercept and gradient for any set of data in which a linear relationship exists. There are various ways of calculating the intercept and gradient values but I was recently playing around with this algorithm in Python and wanted to try it out in R.

The goal of the algorithm is to find the intercept and gradient values which correspond to the lowest possible MSE. It achieves this through iteration over each set of xy data pairs whereby new intercept and gradient values are calculated as well as a new MSE. Then, the new MSE is subtracted from the old MSE and, if the difference is negligible, the optimal values are found.

The function I created below is how I implemented the gradient descent algorithm and applied it to the data we are looking at here. We pass the function our x and y variables. We also pass it the  learning rate which is the magnitude of the steps the algorithm takes along the slope of the MSE function. This can take different values but for this example it is set to 0.0000293. The convergence threshold is set to 0.001. This is the difference between the old MSE and new MSE on each iteration. Finally, we pass the value for n and set the maximum number of iterations we wish to carry out before the loop terminates.

gradientDesc <- function(x, y, learn_rate, conv_threshold, n, max_iter) {
  plot(x, y, col = "blue", pch = 20)
  m <- runif(1, 0, 1)
  c <- runif(1, 0, 1)
  yhat <- m * x + c
  MSE <- sum((y - yhat) ^ 2) / n
  converged = F
  iterations = 0
  while(converged == F) {
    ## Implement the gradient descent algorithm
    m_new <- m - learn_rate * ((1 / n) * (sum((yhat - y) * x)))
    c_new <- c - learn_rate * ((1 / n) * (sum(yhat - y)))
    m <- m_new
    c <- c_new
    yhat <- m * x + c
    MSE_new <- sum((y - yhat) ^ 2) / n
    if(MSE - MSE_new <= conv_threshold) {
      abline(c, m) 
      converged = T
      return(paste("Optimal intercept:", c, "Optimal slope:", m))
    }
    iterations = iterations + 1
    if(iterations > max_iter) { 
      abline(c, m) 
      converged = T
      return(paste("Optimal intercept:", c, "Optimal slope:", m))
    }
  }
}


# Run the function 

gradientDesc(disp, mpg, 0.0000293, 0.001, 32, 2500000)

 

The graphical output and values for the intercept and gradient returned by the function are given below:

rplot02

[1] "Optimal intercept: 29.5998514475591 Optimal slope: -0.0412151087554405"

 

The matches are good. Remember our calculated values for the intercept and gradient from earlier? Here they are again for comparison:

coef(model)

(Intercept)        disp 
29.59985476 -0.04121512

 


To leave a comment for the author, please follow the link and comment on their blog: Environmental Science and Data Analytics.

R-bloggers.com offers daily e-mail updates about R news and tutorials about learning R and many other topics. Click here if you're looking to post or find an R/data-science job.
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

Never miss an update!
Subscribe to R-bloggers to receive
e-mails with the latest R posts.
(You will not see this message again.)

Click here to close (This popup will not appear again)