Recurrent Models and Examples with MXNetR


As a new lightweight and flexible deep learning platform, MXNet provides a portable backend that can be called from R. MXNetR is an R package that provides R users with fast GPU computation and state-of-the-art deep learning models.

In this post, we introduce several high-level APIs for recurrent models in MXNetR. A recurrent neural network (RNN) is a class of artificial neural networks that is popular for sequence labelling tasks such as handwriting recognition and speech recognition.

We will introduce our implementation of the recurrent models, including RNN, LSTM and GRU, and walk through examples such as char-rnn to show how to use them. Several optimizers have also been added to MXNetR.

This post demonstrates the implementation of the recurrent modules, including the structure of the recurrent cells, the unrolling of RNN models, and the specific functions for training and testing:

  1. Three kinds of recurrent cells: the custom RNN, LSTM and GRU cells.

  2. How to unroll RNN models into a common feedforward network.

  3. How to train RNN models, i.e. how to set them up and train them with our specific solver built on the low-level simple-bind interface.

  4. How to use trained RNN models through the RNN inference interfaces, namely the inference and forward functions.

The link to the commits is https://github.com/dmlc/mxnet/commits?author=ziyeqinghan.

Recurrent Models

Since an RNN can be viewed as a deep feedforward neural network unfolded in time, it suffers from the problem of vanishing and exploding gradients. Several variants of the RNN have therefore been proposed to learn long-term dependencies, including Long Short-Term Memory (LSTM) [1] and the Gated Recurrent Unit (GRU) [2].

We will introduce three RNN models, the custom RNN, LSTM and GRU, which have been implemented in MXNetR. For the complete code, please refer to the files rnn.R, lstm.R, gru.R and rnn_model.R in the R-package/R directory.

RNN Cells

The main difference among the three RNN models is that each has a cell with a different structure to mitigate the problem of vanishing and exploding gradients.

Custom RNN Cells

A common RNN can be considered a feedforward network with self-connected hidden layers. As Figure 1 shows, the key property of the RNN is that the recurrent connection allows previous inputs to influence the current output.

Given an input sequence x and the previous state h, a custom RNN cell produces the next states successively. Thus, there are two types of connections: input to hidden (i2h) and (previous) hidden to hidden (h2h). Their sum is passed through a nonlinear activation (e.g. tanh), optionally followed by a batch normalization layer, to produce the output state.
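In equation form, this is the standard RNN update (the notation is mine; i2h and h2h match the names in the code below):

```latex
h_t = \tanh(\; W_{xh} x_t + b_{xh} \;+\; W_{hh} h_{t-1} + b_{hh} \;)
% i2h = W_{xh} x_t + b_{xh},   h2h = W_{hh} h_{t-1} + b_{hh}
```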

<span class="n">rnn</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">num.hidden</span><span class="p">,</span><span class="w"> </span><span class="n">indata</span><span class="p">,</span><span class="w"> </span><span class="n">prev.state</span><span class="p">,</span><span class="w"> </span><span class="n">param</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> 
                </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="n">dropout</span><span class="o">=</span><span class="m">0</span><span class="n">.</span><span class="p">,</span><span class="w"> </span><span class="n">batch.norm</span><span class="o">=</span><span class="kc">FALSE</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
    </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">dropout</span><span class="w"> </span><span class="o">></span><span class="w"> </span><span class="m">0</span><span class="n">.</span><span class="w"> </span><span class="p">)</span><span class="w">
        </span><span class="n">indata</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Dropout</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">indata</span><span class="p">,</span><span class="w"> </span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span><span class="w">
    </span><span class="n">i</span><span class="m">2</span><span class="n">h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">indata</span><span class="p">,</span><span class="w">
                                    </span><span class="n">weight</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">i</span><span class="m">2</span><span class="n">h.weight</span><span class="p">,</span><span class="w">
                                    </span><span class="n">bias</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">i</span><span class="m">2</span><span class="n">h.bias</span><span class="p">,</span><span class="w">
                                    </span><span class="n">num.hidden</span><span class="o">=</span><span class="n">num.hidden</span><span class="p">,</span><span class="w">
                                    </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".i2h"</span><span class="p">))</span><span class="w">
    </span><span class="n">h</span><span class="m">2</span><span class="n">h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">prev.state</span><span class="o">$</span><span class="n">h</span><span class="p">,</span><span class="w">
                                    </span><span class="n">weight</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">h</span><span class="m">2</span><span class="n">h.weight</span><span class="p">,</span><span class="w">
                                    </span><span class="n">bias</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">h</span><span class="m">2</span><span class="n">h.bias</span><span class="p">,</span><span class="w">
                                    </span><span class="n">num.hidden</span><span class="o">=</span><span class="n">num.hidden</span><span class="p">,</span><span class="w">
                                    </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".h2h"</span><span class="p">))</span><span class="w">
    </span><span class="n">hidden</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">i</span><span class="m">2</span><span class="n">h</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">h</span><span class="m">2</span><span class="n">h</span><span class="w">

    </span><span class="n">hidden</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">hidden</span><span class="p">,</span><span class="w"> </span><span class="n">act.type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span><span class="w">
    </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">batch.norm</span><span class="p">)</span><span class="w">
        </span><span class="n">hidden</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.BatchNorm</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">hidden</span><span class="p">)</span><span class="w">
    </span><span class="n">return</span><span class="w"> </span><span class="p">(</span><span class="nf">list</span><span class="p">(</span><span class="n">h</span><span class="o">=</span><span class="n">hidden</span><span class="p">))</span><span class="w">
</span><span class="p">}</span><span class="w">
</span>

LSTM Cells

LSTM replaces the cell in the custom RNN with an LSTM memory block. Figure 2 illustrates the architecture of an LSTM unit.

It contains one memory cell and three multiplicative units: the input gate, the forget gate and the output gate. With the help of the memory cell and the gates, an LSTM can store and learn long-term dependencies across the whole sequence.

An LSTM cell produces the next states based on the input x as well as the previous states (c and h). For the gates, there are three types of connections: input to gate, (previous) hidden to gate and cell to gate. The gate activations use the sigmoid function so that their outputs lie in the range [0, 1]. For the memory cell, there are two connections: input to cell and (previous) hidden to cell.
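In equation form, the cell implements the standard (peephole-free) LSTM update; the notation is mine and mirrors the variable names in the code below:

```latex
i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)   % in.gate
f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)   % forget.gate
o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)   % out.gate
g_t = \tanh (W_{xg} x_t + W_{hg} h_{t-1} + b_g)   % in.transform
c_t = f_t \odot c_{t-1} + i_t \odot g_t           % next.c
h_t = o_t \odot \tanh(c_t)                        % next.h
```

Note that the implementation below omits the cell-to-gate (peephole) connections.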

<span class="n">lstm</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">num.hidden</span><span class="p">,</span><span class="w"> </span><span class="n">indata</span><span class="p">,</span><span class="w"> </span><span class="n">prev.state</span><span class="p">,</span><span class="w"> </span><span class="n">param</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="n">dropout</span><span class="o">=</span><span class="m">0</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
    </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">dropout</span><span class="w"> </span><span class="o">></span><span class="w"> </span><span class="m">0</span><span class="p">)</span><span class="w">
        </span><span class="n">indata</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Dropout</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">indata</span><span class="p">,</span><span class="w"> </span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span><span class="w">
    </span><span class="n">i</span><span class="m">2</span><span class="n">h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">indata</span><span class="p">,</span><span class="w">
                                    </span><span class="n">weight</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">i</span><span class="m">2</span><span class="n">h.weight</span><span class="p">,</span><span class="w">
                                    </span><span class="n">bias</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">i</span><span class="m">2</span><span class="n">h.bias</span><span class="p">,</span><span class="w">
                                    </span><span class="n">num.hidden</span><span class="o">=</span><span class="n">num.hidden</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="m">4</span><span class="p">,</span><span class="w">
                                    </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".i2h"</span><span class="p">))</span><span class="w">
    </span><span class="n">h</span><span class="m">2</span><span class="n">h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">prev.state</span><span class="o">$</span><span class="n">h</span><span class="p">,</span><span class="w">
                                    </span><span class="n">weight</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">h</span><span class="m">2</span><span class="n">h.weight</span><span class="p">,</span><span class="w">
                                    </span><span class="n">bias</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">h</span><span class="m">2</span><span class="n">h.bias</span><span class="p">,</span><span class="w">
                                    </span><span class="n">num.hidden</span><span class="o">=</span><span class="n">num.hidden</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="m">4</span><span class="p">,</span><span class="w">
                                    </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".h2h"</span><span class="p">))</span><span class="w">
    </span><span class="n">gates</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">i</span><span class="m">2</span><span class="n">h</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">h</span><span class="m">2</span><span class="n">h</span><span class="w">
    </span><span class="n">slice.gates</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.SliceChannel</span><span class="p">(</span><span class="n">gates</span><span class="p">,</span><span class="w"> </span><span class="n">num.outputs</span><span class="o">=</span><span class="m">4</span><span class="p">,</span><span class="w">
                                          </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".slice"</span><span class="p">))</span><span class="w">

    </span><span class="n">in.gate</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">slice.gates</span><span class="p">[[</span><span class="m">1</span><span class="p">]],</span><span class="w"> </span><span class="n">act.type</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">)</span><span class="w">
    </span><span class="n">in.transform</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">slice.gates</span><span class="p">[[</span><span class="m">2</span><span class="p">]],</span><span class="w"> </span><span class="n">act.type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span><span class="w">
    </span><span class="n">forget.gate</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">slice.gates</span><span class="p">[[</span><span class="m">3</span><span class="p">]],</span><span class="w"> </span><span class="n">act.type</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">)</span><span class="w">
    </span><span class="n">out.gate</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">slice.gates</span><span class="p">[[</span><span class="m">4</span><span class="p">]],</span><span class="w"> </span><span class="n">act.type</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">)</span><span class="w">
    </span><span class="n">next.c</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="p">(</span><span class="n">forget.gate</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">prev.state</span><span class="o">$</span><span class="n">c</span><span class="p">)</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="p">(</span><span class="n">in.gate</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">in.transform</span><span class="p">)</span><span class="w">
    </span><span class="n">next.h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">out.gate</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">next.c</span><span class="p">,</span><span class="w"> </span><span class="n">act.type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span><span class="w">

    </span><span class="n">return</span><span class="w"> </span><span class="p">(</span><span class="nf">list</span><span class="p">(</span><span class="n">c</span><span class="o">=</span><span class="n">next.c</span><span class="p">,</span><span class="w"> </span><span class="n">h</span><span class="o">=</span><span class="n">next.h</span><span class="p">))</span><span class="w">
</span><span class="p">}</span><span class="w">
</span>

Instead of defining the gates and the memory-cell input independently, we compute them together as a single fully-connected layer of size 4 * num.hidden and then use mx.symbol.SliceChannel to split the result into four outputs. This replaces four separate matrix multiplications with one larger, more efficient one.

GRU Cells

GRU is another variant of the RNN, proposed in 2014 [2]. Similar to the LSTM unit, the GRU unit aims to adaptively capture dependencies on different time scales, using an update gate and a reset gate as shown in Figure 3.

The calculation in a GRU is similar to that in the custom RNN and LSTM models. First, two gates are computed: update.gate, which decides how much the unit updates its activation, and reset.gate, which decides whether to forget the previously computed state. Then the candidate activation htrans is computed from the input data x and the previous state h, where reset.gate controls how much of the previous state is kept. Finally, update.gate interpolates between htrans and the previous state h to produce the next state h.
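In equation form (notation mine, matching the variable names in the code below):

```latex
z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)                    % update.gate
r_t = \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r)                    % reset.gate
\tilde{h}_t = \tanh(W_{xh} x_t + W_{hh}(r_t \odot h_{t-1}) + b_h)  % h.trans.active
h_t = h_{t-1} + z_t \odot (\tilde{h}_t - h_{t-1})                  % next.h
```

The last line is an equivalent rewriting of the usual interpolation h_t = (1 - z_t) h_{t-1} + z_t htrans.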

<span class="n">gru</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">num.hidden</span><span class="p">,</span><span class="w"> </span><span class="n">indata</span><span class="p">,</span><span class="w"> </span><span class="n">prev.state</span><span class="p">,</span><span class="w"> </span><span class="n">param</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="n">dropout</span><span class="o">=</span><span class="m">0</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
    </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">dropout</span><span class="w"> </span><span class="o">></span><span class="w"> </span><span class="m">0</span><span class="p">)</span><span class="w">
        </span><span class="n">indata</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Dropout</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">indata</span><span class="p">,</span><span class="w"> </span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span><span class="w">
    </span><span class="n">i</span><span class="m">2</span><span class="n">h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">indata</span><span class="p">,</span><span class="w">
                                    </span><span class="n">weight</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">gates.i2h.weight</span><span class="p">,</span><span class="w">
                                    </span><span class="n">bias</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">gates.i2h.bias</span><span class="p">,</span><span class="w">
                                    </span><span class="n">num.hidden</span><span class="o">=</span><span class="n">num.hidden</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="m">2</span><span class="p">,</span><span class="w">
                                    </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".gates.i2h"</span><span class="p">))</span><span class="w">
    </span><span class="n">h</span><span class="m">2</span><span class="n">h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">prev.state</span><span class="o">$</span><span class="n">h</span><span class="p">,</span><span class="w">
                                    </span><span class="n">weight</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">gates.h2h.weight</span><span class="p">,</span><span class="w">
                                    </span><span class="n">bias</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">gates.h2h.bias</span><span class="p">,</span><span class="w">
                                    </span><span class="n">num.hidden</span><span class="o">=</span><span class="n">num.hidden</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="m">2</span><span class="p">,</span><span class="w">
                                    </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".gates.h2h"</span><span class="p">))</span><span class="w">
    </span><span class="n">gates</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">i</span><span class="m">2</span><span class="n">h</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">h</span><span class="m">2</span><span class="n">h</span><span class="w">
    </span><span class="n">slice.gates</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.SliceChannel</span><span class="p">(</span><span class="n">gates</span><span class="p">,</span><span class="w"> </span><span class="n">num.outputs</span><span class="o">=</span><span class="m">2</span><span class="p">,</span><span class="w">
                                          </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".slice"</span><span class="p">))</span><span class="w">
    </span><span class="n">update.gate</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">slice.gates</span><span class="p">[[</span><span class="m">1</span><span class="p">]],</span><span class="w"> </span><span class="n">act.type</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">)</span><span class="w">
    </span><span class="n">reset.gate</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">slice.gates</span><span class="p">[[</span><span class="m">2</span><span class="p">]],</span><span class="w"> </span><span class="n">act.type</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">)</span><span class="w">

    </span><span class="n">htrans.i2h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">indata</span><span class="p">,</span><span class="w">
                                           </span><span class="n">weight</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">trans.i2h.weight</span><span class="p">,</span><span class="w">
                                           </span><span class="n">bias</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">trans.i2h.bias</span><span class="p">,</span><span class="w">
                                           </span><span class="n">num.hidden</span><span class="o">=</span><span class="n">num.hidden</span><span class="p">,</span><span class="w">
                                           </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".trans.i2h"</span><span class="p">))</span><span class="w">
    </span><span class="n">h.after.reset</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">prev.state</span><span class="o">$</span><span class="n">h</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">reset.gate</span><span class="w">
    </span><span class="n">htrans.h2h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">h.after.reset</span><span class="p">,</span><span class="w">
                                           </span><span class="n">weight</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">trans.h2h.weight</span><span class="p">,</span><span class="w">
                                           </span><span class="n">bias</span><span class="o">=</span><span class="n">param</span><span class="o">$</span><span class="n">trans.h2h.bias</span><span class="p">,</span><span class="w">
                                           </span><span class="n">num.hidden</span><span class="o">=</span><span class="n">num.hidden</span><span class="p">,</span><span class="w">
                                           </span><span class="n">name</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"t"</span><span class="p">,</span><span class="w"> </span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="s2">".l"</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="p">,</span><span class="w"> </span><span class="s2">".trans.h2h"</span><span class="p">))</span><span class="w">
    </span><span class="n">h.trans</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">htrans.i2h</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">htrans.h2h</span><span class="w">
    </span><span class="n">h.trans.active</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">h.trans</span><span class="p">,</span><span class="w"> </span><span class="n">act.type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span><span class="w">
    </span><span class="n">next.h</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">prev.state</span><span class="o">$</span><span class="n">h</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">update.gate</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="p">(</span><span class="n">h.trans.active</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">prev.state</span><span class="o">$</span><span class="n">h</span><span class="p">)</span><span class="w">
    </span><span class="n">return</span><span class="w"> </span><span class="p">(</span><span class="nf">list</span><span class="p">(</span><span class="n">h</span><span class="o">=</span><span class="n">next.h</span><span class="p">))</span><span class="w">
</span><span class="p">}</span><span class="w">
</span>

Unrolling RNN Models

Since MXNet implements the necessary low-level layers, we can unroll an RNN model in the time dimension and then use MXNet layers to construct the different RNN networks from the cells defined above. After unrolling in time, the model looks just like a common feedforward network, except that the weights are shared across time steps and the recurrent units differ by model. Specifically, we can use fully-connected layers and the corresponding activations to represent the different types of connections.

The unroll function unrolls the recurrent model to a predefined sequence length. The recurrent weights are shared across time, and the network depth corresponds to the number of recurrent layers.

We provide unrolling functions suitable for tasks such as the character language model and the PennTreeBank language model. Unrolling functions for other tasks are similar.

First, we define the weights and states. embed.weight holds the weights of the embedding layer, which maps the one-hot input to a dense vector. cls.weight and cls.bias are the weights and bias of the final prediction at each time step. param.cells and last.states are the weights and states of each cell: the weights are shared across time, while the states differ and are updated over time.

<span class="n">lstm.unroll</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">num.lstm.layer</span><span class="p">,</span><span class="w"> </span><span class="n">seq.len</span><span class="p">,</span><span class="w"> </span><span class="n">input.size</span><span class="p">,</span><span class="w">
                        </span><span class="n">num.hidden</span><span class="p">,</span><span class="w"> </span><span class="n">num.embed</span><span class="p">,</span><span class="w"> </span><span class="n">num.label</span><span class="p">,</span><span class="w"> </span><span class="n">dropout</span><span class="o">=</span><span class="m">0</span><span class="n">.</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">

    </span><span class="n">embed.weight</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="s2">"embed.weight"</span><span class="p">)</span><span class="w">
    </span><span class="n">cls.weight</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="s2">"cls.weight"</span><span class="p">)</span><span class="w">
    </span><span class="n">cls.bias</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="s2">"cls.bias"</span><span class="p">)</span><span class="w">

    </span><span class="n">param.cells</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">lapply</span><span class="p">(</span><span class="m">1</span><span class="o">:</span><span class="n">num.lstm.layer</span><span class="p">,</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">i</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
        </span><span class="n">cell</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">i</span><span class="m">2</span><span class="n">h.weight</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"l"</span><span class="p">,</span><span class="w"> </span><span class="n">i</span><span class="p">,</span><span class="w"> </span><span class="s2">".i2h.weight"</span><span class="p">)),</span><span class="w">
                     </span><span class="n">i</span><span class="m">2</span><span class="n">h.bias</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"l"</span><span class="p">,</span><span class="w"> </span><span class="n">i</span><span class="p">,</span><span class="w"> </span><span class="s2">".i2h.bias"</span><span class="p">)),</span><span class="w">
                     </span><span class="n">h</span><span class="m">2</span><span class="n">h.weight</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"l"</span><span class="p">,</span><span class="w"> </span><span class="n">i</span><span class="p">,</span><span class="w"> </span><span class="s2">".h2h.weight"</span><span class="p">)),</span><span class="w">
                     </span><span class="n">h</span><span class="m">2</span><span class="n">h.bias</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"l"</span><span class="p">,</span><span class="w"> </span><span class="n">i</span><span class="p">,</span><span class="w"> </span><span class="s2">".h2h.bias"</span><span class="p">)))</span><span class="w">
        </span><span class="n">return</span><span class="w"> </span><span class="p">(</span><span class="n">cell</span><span class="p">)</span><span class="w">
    </span><span class="p">})</span><span class="w">
    </span><span class="n">last.states</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">lapply</span><span class="p">(</span><span class="m">1</span><span class="o">:</span><span class="n">num.lstm.layer</span><span class="p">,</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">i</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
        </span><span class="n">state</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">c</span><span class="o">=</span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"l"</span><span class="p">,</span><span class="w"> </span><span class="n">i</span><span class="p">,</span><span class="w"> </span><span class="s2">".init.c"</span><span class="p">)),</span><span class="w">
                      </span><span class="n">h</span><span class="o">=</span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="n">paste0</span><span class="p">(</span><span class="s2">"l"</span><span class="p">,</span><span class="w"> </span><span class="n">i</span><span class="p">,</span><span class="w"> </span><span class="s2">".init.h"</span><span class="p">)))</span><span class="w">
        </span><span class="n">return</span><span class="w"> </span><span class="p">(</span><span class="n">state</span><span class="p">)</span><span class="w">
    </span><span class="p">})</span><span class="w">
</span>

Then we unroll the model in the time dimension and use MXNet layers to construct the network. Here mx.symbol.Embedding produces the embedding vectors for the task at hand (char-rnn in this example). At each time step we reuse the shared weights param.cells and update the states last.states, while last.hidden collects the outputs over time.

<span class="w">    </span><span class="c1"># embeding layer</span><span class="w">
    </span><span class="n">label</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="s2">"label"</span><span class="p">)</span><span class="w">
    </span><span class="n">data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="s2">"data"</span><span class="p">)</span><span class="w">
    </span><span class="n">embed</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Embedding</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">,</span><span class="w"> </span><span class="n">input_dim</span><span class="o">=</span><span class="n">input.size</span><span class="p">,</span><span class="w">
                                 </span><span class="n">weight</span><span class="o">=</span><span class="n">embed.weight</span><span class="p">,</span><span class="w"> </span><span class="n">output_dim</span><span class="o">=</span><span class="n">num.embed</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="o">=</span><span class="s2">"embed"</span><span class="p">)</span><span class="w">
    </span><span class="n">wordvec</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.SliceChannel</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">embed</span><span class="p">,</span><span class="w"> </span><span class="n">num_outputs</span><span class="o">=</span><span class="n">seq.len</span><span class="p">,</span><span class="w"> </span><span class="n">squeeze_axis</span><span class="o">=</span><span class="m">1</span><span class="p">)</span><span class="w">

    </span><span class="n">last.hidden</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">()</span><span class="w">
    </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="n">seqidx</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="n">seq.len</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
        </span><span class="n">hidden</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">wordvec</span><span class="p">[[</span><span class="n">seqidx</span><span class="p">]]</span><span class="w">
        </span><span class="c1"># stack lstm</span><span class="w">
        </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="n">i</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="n">num.lstm.layer</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
            </span><span class="n">dp</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">ifelse</span><span class="p">(</span><span class="n">i</span><span class="o">==</span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="n">dropout</span><span class="p">)</span><span class="w">
            </span><span class="n">next.state</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">lstm</span><span class="p">(</span><span class="n">num.hidden</span><span class="p">,</span><span class="w"> </span><span class="n">indata</span><span class="o">=</span><span class="n">hidden</span><span class="p">,</span><span class="w">
                               </span><span class="n">prev.state</span><span class="o">=</span><span class="n">last.states</span><span class="p">[[</span><span class="n">i</span><span class="p">]],</span><span class="w">
                               </span><span class="n">param</span><span class="o">=</span><span class="n">param.cells</span><span class="p">[[</span><span class="n">i</span><span class="p">]],</span><span class="w">
                               </span><span class="n">seqidx</span><span class="o">=</span><span class="n">seqidx</span><span class="p">,</span><span class="w"> </span><span class="n">layeridx</span><span class="o">=</span><span class="n">i</span><span class="p">,</span><span class="w">
                               </span><span class="n">dropout</span><span class="o">=</span><span class="n">dp</span><span class="p">)</span><span class="w">
            </span><span class="n">hidden</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">next.state</span><span class="o">$</span><span class="n">h</span><span class="w">
            </span><span class="n">last.states</span><span class="p">[[</span><span class="n">i</span><span class="p">]]</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">next.state</span><span class="w">
        </span><span class="p">}</span><span class="w">
        </span><span class="c1"># decoder</span><span class="w">
        </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">dropout</span><span class="w"> </span><span class="o">></span><span class="w"> </span><span class="m">0</span><span class="p">)</span><span class="w">
            </span><span class="n">hidden</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Dropout</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">hidden</span><span class="p">,</span><span class="w"> </span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span><span class="w">
        </span><span class="n">last.hidden</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">last.hidden</span><span class="p">,</span><span class="w"> </span><span class="n">hidden</span><span class="p">)</span><span class="w">
    </span><span class="p">}</span><span class="w">
</span>

Finally, we construct the remaining layers according to the task. Taking char-rnn as an example, cls.weight and cls.bias produce the final prediction, and mx.symbol.SoftmaxOutput connects the predictions to the corresponding labels so that the error can be backpropagated through time.

<span class="w">    </span><span class="n">last.hidden</span><span class="o">$</span><span class="n">dim</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="m">0</span><span class="w">
    </span><span class="n">last.hidden</span><span class="o">$</span><span class="n">num.args</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">seq.len</span><span class="w">
    </span><span class="n">concat</span><span class="w"> </span><span class="o"><-</span><span class="n">mxnet</span><span class="o">:::</span><span class="n">mx.varg.symbol.Concat</span><span class="p">(</span><span class="n">last.hidden</span><span class="p">)</span><span class="w">
    </span><span class="n">fc</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">concat</span><span class="p">,</span><span class="w">
                                   </span><span class="n">weight</span><span class="o">=</span><span class="n">cls.weight</span><span class="p">,</span><span class="w">
                                   </span><span class="n">bias</span><span class="o">=</span><span class="n">cls.bias</span><span class="p">,</span><span class="w">
                                   </span><span class="n">num.hidden</span><span class="o">=</span><span class="n">num.label</span><span class="p">)</span><span class="w">

    </span><span class="n">label</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.transpose</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">label</span><span class="p">)</span><span class="w">
    </span><span class="n">label</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Reshape</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">target.shape</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="m">0</span><span class="p">))</span><span class="w">

    </span><span class="n">loss.all</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.SoftmaxOutput</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="o">=</span><span class="s2">"sm"</span><span class="p">)</span><span class="w">
    </span><span class="n">return</span><span class="w"> </span><span class="p">(</span><span class="n">loss.all</span><span class="p">)</span><span class="w">
</span><span class="p">}</span><span class="w">
</span>
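As an illustration, here is how one might build the symbol for a small char-rnn model with this function. This is a minimal sketch; the concrete numbers are hypothetical, not from the original post:

```r
# unrolled 2-layer LSTM over 32 characters with a vocabulary
# of 128 characters (all numbers are illustrative)
rnn.sym <- lstm.unroll(num.lstm.layer=2,
                       seq.len=32,
                       input.size=128,   # vocabulary size
                       num.hidden=256,
                       num.embed=64,
                       num.label=128,    # predict the next character
                       dropout=0.5)
```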

Training RNN Models

After implementing the unrolled RNN models, we need a way to train them. We write the training code ourselves using the low-level symbol interfaces instead of the common FeedForward model, for two reasons. First, the existing data iterators currently do not support sequence data well. Second, our input contains both the input data x and the states (i.e. h, and c for LSTM), whose gradients must be reset to zero at each epoch, so the existing high-level FeedForward interface is not appropriate for these tasks.

For these reasons, we write our own training code that binds the network and trains it epoch by epoch through the low-level symbol interfaces; see rnn_model.R. The training procedure is the same for all three RNN models except for the initial states, which are init.h for the custom RNN and GRU, and both init.c and init.h for LSTM, so they can share the same training function.
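As an illustration of the state handling, the initial states can be reset to zero NDArrays at the start of each epoch. A minimal sketch, assuming init.states.name and the state shapes defined for setup.rnn.model below (variable names are illustrative):

```r
# reset every initial state of the bound executor to a zero NDArray
init.states <- lapply(init.states.name, function(name) {
    mx.nd.zeros(c(num.hidden, batch.size))
})
names(init.states) <- init.states.name
mx.exec.update.arg.arrays(rnn.exec, init.states, match.name=TRUE)
```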

The training code mainly consists of two functions: setup.rnn.model and train.rnn.
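Both are internal building blocks: end users typically go through high-level wrappers such as mx.lstm, as in the char-rnn example. A hedged sketch of such a call follows; the data iterators, numbers and argument names are illustrative and may differ across MXNetR versions:

```r
# hypothetical call to the high-level char-rnn training interface;
# X.train / X.val are sequence data iterators prepared for char-rnn
model <- mx.lstm(X.train, X.val,
                 ctx=mx.gpu(0),
                 num.round=20,
                 num.lstm.layer=2,
                 seq.len=32,
                 num.hidden=256,
                 num.embed=64,
                 num.label=128,
                 batch.size=128,
                 input.size=128)
```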

Set Up RNN Models

setup.rnn.model initializes the parameters and binds the network. The parameter init.states.name stores the names of the initial states: for the custom RNN and GRU, the names end with init.h; for LSTM, they end with init.c and init.h.
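For an LSTM, for instance, the list can be built from the variable names used in lstm.unroll (a minimal sketch):

```r
# initial-state names matching the variables defined in lstm.unroll
init.states.c <- paste0("l", 1:num.lstm.layer, ".init.c")
init.states.h <- paste0("l", 1:num.lstm.layer, ".init.h")
init.states.name <- c(init.states.c, init.states.h)
```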

First, we set the dimensions of the inputs, including the input data, the label and the initial states. Given the input shapes, we can use the mx.model.init.params.rnn function to initialize the parameters.

<span class="n">setup.rnn.model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">rnn.sym</span><span class="p">,</span><span class="w"> </span><span class="n">ctx</span><span class="p">,</span><span class="w">
                            </span><span class="n">num.rnn.layer</span><span class="p">,</span><span class="w"> </span><span class="n">seq.len</span><span class="p">,</span><span class="w">
                            </span><span class="n">num.hidden</span><span class="p">,</span><span class="w"> </span><span class="n">num.embed</span><span class="p">,</span><span class="w"> </span><span class="n">num.label</span><span class="p">,</span><span class="w">
                            </span><span class="n">batch.size</span><span class="p">,</span><span class="w"> </span><span class="n">input.size</span><span class="p">,</span><span class="w">
                            </span><span class="n">init.states.name</span><span class="p">,</span><span class="w">
                            </span><span class="n">initializer</span><span class="o">=</span><span class="n">mx.init.uniform</span><span class="p">(</span><span class="m">0.01</span><span class="p">),</span><span class="w">
                            </span><span class="n">dropout</span><span class="o">=</span><span class="m">0</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">

    </span><span class="n">arg.names</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">rnn.sym</span><span class="o">$</span><span class="n">arguments</span><span class="w">
    </span><span class="n">input.shapes</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">()</span><span class="w">
    </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="n">name</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="n">arg.names</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
        </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">name</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="n">init.states.name</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
            </span><span class="n">input.shapes</span><span class="p">[[</span><span class="n">name</span><span class="p">]]</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">num.hidden</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="p">)</span><span class="w">
        </span><span class="p">}</span><span class="w">
        </span><span class="k">else</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">grepl</span><span class="p">(</span><span class="s1">'data$'</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="p">)</span><span class="w"> </span><span class="o">||</span><span class="w"> </span><span class="n">grepl</span><span class="p">(</span><span class="s1">'label$'</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="p">)</span><span class="w"> </span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
            </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="n">seq.len</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="m">1</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
                </span><span class="n">input.shapes</span><span class="p">[[</span><span class="n">name</span><span class="p">]]</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">batch.size</span><span class="p">)</span><span class="w">
            </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span><span class="w">
            </span><span class="n">input.shapes</span><span class="p">[[</span><span class="n">name</span><span class="p">]]</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">seq.len</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="p">)</span><span class="w">
            </span><span class="p">}</span><span class="w">
        </span><span class="p">}</span><span class="w">
    </span><span class="p">}</span><span class="w">
    </span><span class="n">params</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.model.init.params.rnn</span><span class="p">(</span><span class="n">rnn.sym</span><span class="p">,</span><span class="w"> </span><span class="n">input.shapes</span><span class="p">,</span><span class="w"> </span><span class="n">initializer</span><span class="p">,</span><span class="w"> </span><span class="n">mx.cpu</span><span class="p">())</span><span class="w">
</span>

Next, we use mx.simple.bind to bind the network and set arg.arrays, aux.arrays and grad.arrays.

<span class="w">    </span><span class="n">args</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">input.shapes</span><span class="w">
    </span><span class="n">args</span><span class="o">$</span><span class="n">symbol</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">rnn.sym</span><span class="w">
    </span><span class="n">args</span><span class="o">$</span><span class="n">ctx</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">ctx</span><span class="w">
    </span><span class="n">args</span><span class="o">$</span><span class="n">grad.req</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="s2">"add"</span><span class="w">
    </span><span class="n">rnn.exec</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">do.call</span><span class="p">(</span><span class="n">mx.simple.bind</span><span class="p">,</span><span class="w"> </span><span class="n">args</span><span class="p">)</span><span class="w">

    </span><span class="n">mx.exec.update.arg.arrays</span><span class="p">(</span><span class="n">rnn.exec</span><span class="p">,</span><span class="w"> </span><span class="n">params</span><span class="o">$</span><span class="n">arg.params</span><span class="p"...
