
## Trying gradient descent for linear regression

The best way to learn an algorithm is to code it. So here is my take on the gradient descent algorithm for simple linear regression.

First, we fit a simple linear model with `lm` to get reference coefficients to compare against the gradient descent values.

```
# Load libraries

library(dplyr)
library(highcharter)

# Scale the length variables from the iris dataset

iris_demo <- iris[, c("Sepal.Length", "Petal.Length")] %>%
  mutate(sepal_length = as.numeric(scale(Sepal.Length)),
         petal_length = as.numeric(scale(Petal.Length))) %>%
  select(sepal_length, petal_length)

# Fit a simple linear model to compare coefficients

regression <- lm(iris_demo$petal_length ~ iris_demo$sepal_length)

coef(regression)
```
```
##            (Intercept) iris_demo$sepal_length
##           4.643867e-16           8.717538e-01
```
```
iris_demo_reg <- iris_demo

iris_demo_reg$reg <- predict(regression, iris_demo)

# Plot the model with highcharter

highchart() %>%
  hc_add_series(data = iris_demo_reg, type = "scatter", hcaes(x = sepal_length, y = petal_length), name = "Sepal Length VS Petal Length") %>%
  hc_add_series(data = iris_demo_reg, type = "line", hcaes(x = sepal_length, y = reg), name = "Linear Regression") %>%
  hc_title(text = "Linear Regression")
```
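
As a quick sanity check on the `lm` coefficients above: when both variables are standardized, the least-squares slope equals the Pearson correlation of the two variables, and the intercept is zero up to floating-point error. A minimal sketch using only base R (the names `x`, `y`, `fit`, `slope`, and `intercept` are illustrative and not part of the original code):

```r
# Standardize both variables, as in the post (illustrative names)
x <- as.numeric(scale(iris$Sepal.Length))
y <- as.numeric(scale(iris$Petal.Length))

fit <- lm(y ~ x)

slope <- unname(coef(fit)[2])
intercept <- unname(coef(fit)[1])

# The slope should match the Pearson correlation of the raw variables
all.equal(slope, cor(iris$Sepal.Length, iris$Petal.Length))

# The intercept should be numerically indistinguishable from zero
abs(intercept) < 1e-10
```

This is why the intercept reported by `lm` is on the order of 1e-16 rather than exactly 0.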


We will try to obtain the same coefficients, this time using gradient descent.
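
For reference, the quantities that the code below accumulates can be written out explicitly. This assumes the cost is the mean squared error over the $n$ observations, as the code comments indicate; the symbols $J$ for the cost and $\alpha$ for the learning rate are introduced here for illustration only.

```latex
J(m, b) = \frac{1}{n} \sum_{i=1}^{n} \bigl( y_i - (m x_i + b) \bigr)^2

\frac{\partial J}{\partial b} = -\frac{2}{n} \sum_{i=1}^{n} \bigl( y_i - (m x_i + b) \bigr)
\qquad
\frac{\partial J}{\partial m} = -\frac{2}{n} \sum_{i=1}^{n} x_i \bigl( y_i - (m x_i + b) \bigr)

b \leftarrow b - \alpha \, \frac{\partial J}{\partial b}
\qquad
m \leftarrow m - \alpha \, \frac{\partial J}{\partial m}
```

In the code, `b_iter` and `m_iter` hold the two partial derivatives, and `learning_rate` plays the role of $\alpha$.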

```
library(tidyr)

set.seed(135) # To reproduce results

# Auxiliary function

# y = mx + b

reg <- function(m, b, x) return(m * x + b)

# Starting point

b <- runif(1)
m <- runif(1)

gradient_desc <- function(b, m, data, learning_rate = 0.01) { # Small steps

  # Column names = Code easier to understand

  colnames(data) <- c("x", "y")

  # Values for first iteration

  b_iter <- 0
  m_iter <- 0
  n <- nrow(data)

  # Compute the gradient for Mean Squared Error function

  for (i in 1:n) {

    # Partial derivative for b

    b_iter <- b_iter + (-2/n) * (data$y[i] - ((m * data$x[i]) + b))

    # Partial derivative for m

    m_iter <- m_iter + (-2/n) * data$x[i] * (data$y[i] - ((m * data$x[i]) + b))

  }

  # Move to the OPPOSITE direction of the derivative

  new_b <- b - (learning_rate * b_iter)
  new_m <- m - (learning_rate * m_iter)

  # Replace values and return

  new <- list(new_b, new_m)

  return(new)

}

# I need to store some values to make the motion plot

vect_m <- m
vect_b <- b

# Iterate to obtain better parameters

for (i in 1:1000) {
  if (i %in% c(1, 100, 250, 500)) { # I keep some values in the iteration for the plot
    vect_m <- c(vect_m, m)
    vect_b <- c(vect_b, b)
  }
  x <- gradient_desc(b, m, iris_demo)
  b <- x[[1]]
  m <- x[[2]]
}

print(paste0("m = ", m))
```
`## [1] "m = 0.871753774273602"`
```
print(paste0("b = ", b))
```
`## [1] "b = 5.52239677041512e-10"`

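The difference in the coefficients is minimal: the slope agrees with the `lm` estimate to the seven digits shown (0.8717538), and both intercepts are effectively zero.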
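
The per-observation loop inside `gradient_desc` can also be expressed with vectorized means. The following is a minimal sketch, not the code used for the results above: the function name `gradient_step`, the fixed starting values, and the variable names are assumptions made for illustration, and it uses only base R.

```r
# Standardized variables, as in the post (illustrative names)
x <- as.numeric(scale(iris$Sepal.Length))
y <- as.numeric(scale(iris$Petal.Length))

# One gradient descent step on the mean squared error, vectorized
gradient_step <- function(b, m, x, y, learning_rate = 0.01) {
  residual <- y - (m * x + b)
  grad_b <- -2 * mean(residual)      # partial derivative with respect to b
  grad_m <- -2 * mean(x * residual)  # partial derivative with respect to m
  list(b = b - learning_rate * grad_b,
       m = m - learning_rate * grad_m)
}

# Arbitrary fixed starting point (the post draws it with runif instead)
params <- list(b = 0.5, m = 0.5)

for (i in 1:1000) {
  params <- gradient_step(params$b, params$m, x, y)
}

# Compare with the closed-form least-squares fit
fit <- lm(y ~ x)
max_diff <- max(abs(c(params$b, params$m) - unname(coef(fit))))
max_diff
```

With the same 1,000 steps and learning rate of 0.01, `max_diff` should come out very small, in line with the agreement seen above. Replacing the loop with `mean()` does not change the update; it only avoids indexing the data frame one row at a time.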

We can see how the iterations work in the next plot:

```
# Compute new values

iris_demo$preit   <- reg(vect_m[1], vect_b[1], iris_demo$sepal_length)
iris_demo$it1     <- reg(vect_m[2], vect_b[2], iris_demo$sepal_length)
iris_demo$it100   <- reg(vect_m[3], vect_b[3], iris_demo$sepal_length)
iris_demo$it250   <- reg(vect_m[4], vect_b[4], iris_demo$sepal_length)
iris_demo$it500   <- reg(vect_m[5], vect_b[5], iris_demo$sepal_length)
iris_demo$finalit <- reg(m, b, iris_demo$sepal_length)

iris_gathered <- iris_demo %>% gather(key = gr, value = val, preit:finalit) %>%
  select(-petal_length) %>%
  distinct()

iris_start <- iris_gathered %>%
  filter(gr == "preit")

iris_seq <- iris_gathered %>%
  group_by(sepal_length) %>%
  do(sequence = list_parse(select(., y = val)))

iris_data <- left_join(iris_start, iris_seq)

# Motion Plot

irhc2 <- highchart() %>%
  hc_add_series(data = iris_data, type = "line", hcaes(x = sepal_length, y = val), name = "Gradient Descent") %>%
  hc_motion(enabled = TRUE, series = 0, startIndex = 0,
            labels = c("Iteration 1", "Iteration 100", "Iteration 250", "Iteration 500", "Final Iteration")) %>%
  hc_add_series(data = iris_demo_reg, type = "scatter", hcaes(x = sepal_length, y = petal_length), name = "Sepal Length VS Petal Length") %>%
  hc_title(text = "Gradient Descent Iterations")

irhc2
```


Maybe in a future post we can try a multivariate regression model!