Site icon R-bloggers

RNN made easy with MXNet R

[This article was first published on DMLC, and kindly contributed to R-bloggers]. (You can report issue about the content on this page here)
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

This tutorial presents an example of application of RNN to text classification using padded and bucketed data to efficiently handle sequences of varying lengths. Some functionalities require running on a GPU with CUDA.

Example based on sentiment analysis on the IMDB data.

What’s special about sequence modeling?

Whether working with times series or text at the character or word level, modeling sequences typically involves dealing with samples of varying length.

To efficiently feed the Recurrent Neural Network (RNN) with samples of even length within each batch, two tricks can be used:

Non numeric features such as words need to be transformed into a numeric representation. This task is commonly performed by the embedding operator which first requires to convert words into a 0 based index. The embedding will map a vector of features based on that index. In the example below, the embedding projects each word into 2 new numeric features.

Data preparation

For this demo, the data preparation is performed by the script data_preprocessing_seq_to_one.R which involves the following steps:

To illustrate the benefit of bucketing, two datasets are created:

Below is the example of the assignation of the bucketed data and labels into mx.io.bucket.iter iterator. This iterator behaves essentially the same as the mx.io.arrayiter except that is pushes samples coming from the different buckets along with a bucketID to identify the appropriate network to use.

corpus_bucketed_train <- readRDS(file = "data/corpus_bucketed_train.rds")
corpus_bucketed_test <- readRDS(file = "data/corpus_bucketed_test.rds")

vocab <- length(corpus_bucketed_test$dic)

### Create iterators
batch.size = 64

train.data.bucket <- mx.io.bucket.iter(buckets = corpus_bucketed_train$buckets, 
                                batch.size = batch.size, 
                                data.mask.element = 0, shuffle = TRUE)

eval.data.bucket <- mx.io.bucket.iter(buckets = corpus_bucketed_test$buckets, 
                                batch.size = batch.size, 
                                data.mask.element = 0, shuffle = FALSE)

Define the architecture

Below are the graph representations of a seq-to-one architecture with LSTM cells. Note that input data is of shape batch.size X seq.length while the output of the RNN operator is of shape hidden.features X batch.size X seq.length.

For bucketing, a list of symbols is defined, one for each bucket length. During training, at each batch the appropriate symbol is bound according to the bucketID provided by the iterator.

symbol_single <- rnn.graph(config = "seq-to-one", cell.type = "lstm", 
                           num.rnn.layer = 1, num.embed = 2, num.hidden = 4, 
                           num.decode = 2, input.size = vocab, dropout = 0.5, 
                           ignore_label = -1, loss_output = "softmax",
                           output_last_state = F, masking = T)
bucket_list <- unique(c(train.data.bucket$bucket.names, eval.data.bucket$bucket.names))

symbol_buckets <- sapply(bucket_list, function(seq) {
  rnn.graph(config = "seq-to-one", cell.type = "lstm", 
                   num.rnn.layer = 1, num.embed = 2, num.hidden = 4, 
                   num.decode = 2, input.size = vocab, dropout = 0.5, 
                   ignore_label = -1, loss_output = "softmax",
                   output_last_state = F, masking = T)})

graph.viz(symbol_single, type = "graph", direction = "LR", 
          graph.height.px = 50, graph.width.px = 800, shape=c(64, 5))

The representation of an unrolled RNN typically assumes a fixed length sequence. The operator mx.symbol.RNN simplifies the process by abstracting the recurrent cells into a single operator that accepts batches of varying length (each batch contains sequences of identical length).

Train the model

First the non bucketed model is trained for 6 epochs:

devices <- mx.gpu(0)

initializer <- mx.init.Xavier(rnd_type = "gaussian", factor_type = "avg", magnitude = 2.5)

optimizer <- mx.opt.create("rmsprop", learning.rate = 1e-3, gamma1 = 0.95, gamma2 = 0.92, 
                           wd = 1e-4, clip_gradient = 5, rescale.grad=1/batch.size)

logger <- mx.metric.logger()
epoch.end.callback <- mx.callback.log.train.metric(period = 1, logger = logger)
batch.end.callback <- mx.callback.log.train.metric(period = 50)

system.time(
  model <- mx.model.buckets(symbol = symbol_single,
                            train.data = train.data.single, eval.data = eval.data.single,
                            num.round = 5, ctx = devices, verbose = FALSE,
                            metric = mx.metric.accuracy, optimizer = optimizer,  
                            initializer = initializer,
                            batch.end.callback = NULL, 
                            epoch.end.callback = epoch.end.callback)
)
##    user  system elapsed 
## 205.214  17.253 210.265

Then training with the bucketing trick. Note that no additional effort is required: just need to provide a list of symbols rather than a single one and have an iterator pushing samples from the different buckets.

devices <- mx.gpu(0)

initializer <- mx.init.Xavier(rnd_type = "gaussian", factor_type = "avg", magnitude = 2.5)

optimizer <- mx.opt.create("rmsprop", learning.rate = 1e-3, gamma1 = 0.95, gamma2 = 0.92, 
                           wd = 1e-4, clip_gradient = 5, rescale.grad=1/batch.size)

logger <- mx.metric.logger()
epoch.end.callback <- mx.callback.log.train.metric(period = 1, logger = logger)
batch.end.callback <- mx.callback.log.train.metric(period = 50)

system.time(
  model <- mx.model.buckets(symbol = symbol_buckets,
                            train.data = train.data.bucket, eval.data = eval.data.bucket,
                            num.round = 5, ctx = devices, verbose = FALSE,
                            metric = mx.metric.accuracy, optimizer = optimizer,  
                            initializer = initializer,
                            batch.end.callback = NULL, 
                            epoch.end.callback = epoch.end.callback)
)
##    user  system elapsed 
## 129.578  11.500 125.120

The speedup is substantial, around 125 sec. instead of 210 sec., a 40% speedup with little effort.

Plot word embeddings

Word representation can be visualized by looking at the assigned weights in any of the embedding dimensions. Here, we look simultaneously at the two embeddings learnt in the LSTM model.

Since the model attempts to predict the sentiment, it’s no surprise that the 2 dimensions into which each word is projected appear correlated with words’ polarity. Positive words are associated with lower X1 values (“great”, “excellent”), while the most negative words appear at the far right (“terrible”, “worst”). By representing words of similar meaning with features of values, embedding much facilitates the remaining classification task for the network.

Inference on test data

The utility function mx.infer.buckets has been added to simplify inference on RNN with bucketed data.

ctx <- mx.gpu(0)
batch.size <- 64

corpus_bucketed_test <- readRDS(file = "data/corpus_bucketed_test.rds")

test.data <- mx.io.bucket.iter(buckets = corpus_bucketed_test$buckets, 
                              batch.size = batch.size, 
                              data.mask.element = 0, shuffle = FALSE)
infer <- mx.infer.buckets(infer.data = test.data, model = model, ctx = ctx)

pred_raw <- t(as.array(infer))
pred <- max.col(pred_raw, tie = "first") - 1
label <- unlist(lapply(corpus_bucketed_test$buckets, function(x) x$label))

acc <- sum(label == pred)/length(label)
roc <- roc(predictions = pred_raw[, 2], labels = factor(label))
auc <- auc(roc)

Accuracy: 87.6%

AUC: 0.9436

To leave a comment for the author, please follow the link and comment on their blog: DMLC.

R-bloggers.com offers daily e-mail updates about R news and tutorials about learning R and many other topics. Click here if you're looking to post or find an R/data-science job.
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.