Restricted Boltzmann Machines in R

January 14, 2013

(This article was first published on Statistically Significant, and kindly contributed to R-bloggers)

Restricted Boltzmann Machines (RBMs) are an unsupervised learning method (like principal components). An RBM is a probabilistic, undirected graphical model. RBMs are becoming more popular in machine learning thanks to recent success in training them with contrastive divergence. They have proven useful in collaborative filtering, being one of the most successful methods in the Netflix challenge (paper). Furthermore, they have been instrumental in training deep learning models, which appear to be the best current models for image and digit recognition.

I do not want to go into too much detail, but I would like to give a high-level overview of RBMs. Edwin Chen has an introduction that is much better. The usual version of an RBM is a probabilistic model for binary vectors. An image can be represented as a binary vector if each pixel that is dark enough is represented as a 1 and the other pixels as 0. In addition to the visible binary vector, the model assumes a hidden binary vector, and each hidden unit is connected to every visible unit by a symmetric weight. The weights are collected in the matrix W, whose (i, j)th entry is the weight between hidden unit i and visible unit j. It is important that there are no connections between hidden units or between visible units: given the hidden units, the visible units are independent of each other, and vice versa. The probability that visible unit j is 1 is the logistic function of the sum of the weights connecting visible unit j to the hidden units that are 1. In math notation (where σ is the logistic, or sigmoid, function):
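
$$P(v_j = 1 \mid \mathbf{h}) = \sigma\Big(\sum_i W_{ij} h_i\Big), \qquad \sigma(x) = \frac{1}{1 + e^{-x}}$$

The weights are symmetric, so the same matrix gives the probability that hidden unit i is 1:

$$P(h_i = 1 \mid \mathbf{v}) = \sigma\Big(\sum_j W_{ij} v_j\Big)$$
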
As you can see, the model is defined by its conditional probabilities. The task is to find the weight matrix W that best explains the visible data for a given number of hidden units.
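
To convince yourself that the two conditionals fit together, here is a small brute-force check. It assumes the bias-free RBM used in this post, where the joint probability of (v, h) is proportional to exp(h'Wv), the "goodness" that configuration_goodness computes in the code below. For a tiny model we can enumerate every hidden configuration, normalize, and compare the exact P(h_i = 1 | v) with σ(Wv). The sizes, the random W and the visible vector are arbitrary choices for illustration.

```r
# Brute-force check for a tiny bias-free RBM, assuming
# p(v, h) is proportional to exp(t(h) %*% W %*% v).
set.seed(1)
sigmoid <- function(x) 1 / (1 + exp(-x))

n_hidden <- 3
n_visible <- 4
W <- matrix(rnorm(n_hidden * n_visible), n_hidden, n_visible)
v <- c(1, 0, 1, 1)

# All 2^n_hidden binary hidden configurations, one per row
H <- as.matrix(expand.grid(rep(list(0:1), n_hidden)))

# p(h | v) for each configuration: exp(goodness), then normalize
goodness <- as.vector(H %*% (W %*% v))
p_h_given_v <- exp(goodness) / sum(exp(goodness))

# Exact marginal P(h_i = 1 | v), summing over the enumeration
exact <- as.vector(t(H) %*% p_h_given_v)

# Closed-form conditional from the formula above
closed_form <- as.vector(sigmoid(W %*% v))

print(cbind(exact, closed_form))
all.equal(exact, closed_form)
```

The two columns agree because exp(h'Wv) factorizes over the hidden units once v is fixed, which is exactly the "no connections between hidden units" property.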

I have been taking Geoff Hinton's Coursera course on neural networks, which I would recommend to anyone interested. One of the programming assignments was to fill in code for an RBM in Matlab/Octave. I have since looked for a version in R without any luck, so I decided to translate the code myself. Below is the code to calculate the weight matrix. It is fairly simple and only uses one step of contrastive divergence (CD-1). The input data is p×n (visible units by data cases) instead of the usual n×p.

# rbm_w is a matrix of size <number of hidden units> by <number of visible units>
# visible_state is a matrix of size <number of visible units> by <number of data cases>
# hidden_state is a binary matrix of size <number of hidden units> by <number of data cases>

# P(h = 1 | v) for every hidden unit and data case: sigmoid(rbm_w %*% v)
visible_state_to_hidden_probabilities <- function(rbm_w, visible_state) {
  1/(1+exp(-rbm_w %*% visible_state))
}
# P(v = 1 | h) for every visible unit and data case: sigmoid(t(rbm_w) %*% h)
hidden_state_to_visible_probabilities <- function(rbm_w, hidden_state) {
  1/(1+exp(-t(rbm_w) %*% hidden_state))
}
# Average goodness (negative energy) t(h) %*% rbm_w %*% v over the data cases
configuration_goodness <- function(rbm_w, visible_state, hidden_state) {
  out = 0
  for (i in 1:dim(visible_state)[2]) {
    out = out + t(hidden_state[,i]) %*% rbm_w %*% visible_state[,i]
  }
  as.numeric(out)/dim(visible_state)[2]
}
# Gradient of the average goodness with respect to rbm_w
configuration_goodness_gradient <- function(visible_state, hidden_state) {
  hidden_state %*% t(visible_state)/dim(visible_state)[2]
}
# Binary matrix with the same shape as mat; entry (i,j) is 1 with probability mat[i,j]
sample_bernoulli <- function(mat) {
  dims = dim(mat)
  matrix(rbinom(prod(dims), size=1, prob=c(mat)), dims[1], dims[2])
}
# CD-1: statistics at the data minus statistics after one Gibbs step
cd1 <- function(rbm_w, visible_data) {
  visible_data = sample_bernoulli(visible_data)
  H0 = sample_bernoulli(visible_state_to_hidden_probabilities(rbm_w, visible_data))
  vh0 = configuration_goodness_gradient(visible_data, H0)
  V1 = sample_bernoulli(hidden_state_to_visible_probabilities(rbm_w, H0))
  H1 = visible_state_to_hidden_probabilities(rbm_w, V1)
  vh1 = configuration_goodness_gradient(V1, H1)
  vh0 - vh1
}
rbm <- function(num_hidden, training_data, learning_rate, n_iterations, mini_batch_size=100, momentum=0.9, quiet=FALSE) {
  # This trains a model that's defined by a single matrix of weights.
  # <num_hidden> is the number of hidden units.
  # <training_data> is a p x n matrix (visible units by data cases).
  # cd1(model, data) returns the gradient (or approximate gradient in the case of CD-1) of the function that we're maximizing. Note the contrast with the loss function that we saw in PA3, which we were minimizing. The returned gradient has the same shape as <model>.
  # This uses mini-batches, no weight decay and no early stopping.
  # This returns the matrix of weights of the trained model.
  p = dim(training_data)[1]
  n = dim(training_data)[2]
  if (n %% mini_batch_size != 0) {
    stop("the number of training cases must be divisible by the mini_batch_size")
  }
  model = (matrix(runif(num_hidden*p),num_hidden,p) * 2 - 1) * 0.1
  momentum_speed = matrix(0,num_hidden,p)
  start_of_next_mini_batch = 1
  for (iteration_number in 1:n_iterations) {
    if (!quiet) {cat("Iter",iteration_number,"\n")}
    mini_batch = training_data[, start_of_next_mini_batch:(start_of_next_mini_batch + mini_batch_size - 1), drop=FALSE]
    # Wrap around to the first case after the last mini-batch
    start_of_next_mini_batch = (start_of_next_mini_batch - 1 + mini_batch_size) %% n + 1
    gradient = cd1(model, mini_batch)
    momentum_speed = momentum * momentum_speed + gradient
    model = model + momentum_speed * learning_rate
  }
  model
}
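
In cd1, vh0 and vh1 are configuration_goodness_gradient evaluated at the data and at the one-step reconstruction; their difference is the CD-1 approximation to the gradient mentioned in the comments of rbm(). The sketch below checks, with central finite differences, that h v'/N really is the derivative of the average goodness with respect to the weights. The matrix sizes are arbitrary, and avg_goodness and goodness_grad are hypothetical stand-ins that mirror configuration_goodness and configuration_goodness_gradient so the check runs on its own.

```r
# Finite-difference check that H %*% t(V) / N is the gradient of the
# average goodness mean over cases of t(h_c) %*% W %*% v_c, w.r.t. W.
set.seed(2)
n_hidden <- 3; n_visible <- 5; n_cases <- 4
W <- matrix(rnorm(n_hidden * n_visible), n_hidden, n_visible)
V <- matrix(rbinom(n_visible * n_cases, 1, 0.5), n_visible, n_cases)
H <- matrix(rbinom(n_hidden * n_cases, 1, 0.5), n_hidden, n_cases)

# Average goodness over the data cases (mirrors configuration_goodness)
avg_goodness <- function(W, V, H) sum((W %*% V) * H) / ncol(V)

# Analytic gradient (mirrors configuration_goodness_gradient)
goodness_grad <- function(V, H) H %*% t(V) / ncol(V)

# Central finite differences, one weight at a time
eps <- 1e-6
numeric_grad <- matrix(0, n_hidden, n_visible)
for (i in 1:n_hidden) {
  for (j in 1:n_visible) {
    Wp <- W; Wp[i, j] <- Wp[i, j] + eps
    Wm <- W; Wm[i, j] <- Wm[i, j] - eps
    numeric_grad[i, j] <- (avg_goodness(Wp, V, H) - avg_goodness(Wm, V, H)) / (2 * eps)
  }
}

# Largest discrepancy between the numeric and analytic gradients
max(abs(numeric_grad - goodness_grad(V, H)))
```

Because the goodness is linear in W, the discrepancy is only floating-point rounding. The "contrastive" part of CD-1 is then just this gradient at the data minus the same gradient at the reconstruction.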

I loaded the handwritten digit data that was given in the class into a matrix called train, with one column per image. To train the RBM, use the syntax below.

weights=rbm(num_hidden=30, training_data=train, learning_rate=.09, n_iterations=5000,
mini_batch_size=100, momentum=0.9)
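
rbm() expects training_data with one binary (or probability) column per data case. If your images instead come as an n×p matrix of grayscale intensities in [0, 1], one image per row, a minimal sketch of the preparation described at the top of the post (dark pixels become 1, the rest 0, then transpose) could look like this. The 0.5 threshold, the assumption that larger values mean darker pixels, and the helper name to_rbm_input are all hypothetical choices for illustration.

```r
# Hypothetical helper: turn an n x p matrix of intensities in [0, 1]
# (one image per row, larger = darker) into a p x n binary matrix,
# the layout rbm() takes as training_data.
to_rbm_input <- function(images, threshold = 0.5) {
  binary <- (images > threshold) * 1   # 1 for dark pixels, 0 otherwise
  t(binary)                            # p x n: visible units by data cases
}

# Toy example: 3 "images" with 4 pixels each
images <- matrix(c(0.9, 0.1, 0.6, 0.2,
                   0.3, 0.8, 0.4, 0.7,
                   0.5, 0.5, 1.0, 0.0),
                 nrow = 3, byrow = TRUE)
x <- to_rbm_input(images)
dim(x)   # 4 3
x
```

Remember that rbm() also requires the number of columns to be divisible by mini_batch_size.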

After training the weights, I can visualize them. The code below plots the strength of the weights going to each pixel. Each facet displays the weights going to/coming from one of the hidden units. I trained 30 units so that it would be easy to show them all on one plot. You can see that most of the hidden units correspond strongly to one digit or another. I think this means the RBM is finding the joint structure of the pixels: the pixels that make up a given digit are often dark together.

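library(reshape2); library(ggplot2)  # melt() and ggplot()
mw=melt(weights); mw$Var3=floor((mw$Var2-1)/16)+1; mw$Var2=(mw$Var2-1)%%16 + 1; mw$Var3=17-mw$Var3;
ggplot(mw, aes(x=Var2, y=Var3, fill=value)) + geom_tile() + facet_wrap(~Var1) + coord_equal() +
labs(x=NULL,y=NULL,title="Visualization of Weights")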
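
The index arithmetic in the melt line maps weight column k (one of the 16×16 = 256 pixels) to image row floor((k-1)/16)+1 and image column (k-1) %% 16 + 1, which assumes each image is stored row by row; the 17 - row flip then draws row 1 at the top. A small self-contained check of that mapping against matrix(..., byrow = TRUE), using a made-up weight vector in place of a row of weights, is below.

```r
# Check the pixel-index arithmetic used in the plotting code, assuming
# each 256-long weight row stores a 16 x 16 image row by row.
w <- seq_len(256)            # stand-in for one row of the weight matrix
k <- seq_along(w)            # pixel (column) index, as in mw$Var2
img_row <- floor((k - 1) / 16) + 1
img_col <- (k - 1) %% 16 + 1

# Place each weight at (img_row, img_col) ...
img <- matrix(NA_real_, 16, 16)
img[cbind(img_row, img_col)] <- w

# ... which should match filling a 16 x 16 matrix by rows
identical(img, matrix(as.numeric(w), 16, 16, byrow = TRUE))
```

If your images were stored column by column instead, the roles of the two index formulas would swap and the digits would appear transposed.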
