
There is a great RcppArmadillo implementation of multivariate normal densities. But I was looking for the first derivative of the multivariate normal density, and good implementations are surprisingly hard to come by. I wasn’t able to find any online, and my first solutions in R were pretty slow. RcppArmadillo is an attractive alternative, particularly because I am not aware of any C or Fortran versions that can be called from R; it is in such situations that we can expect the largest performance gains. Indeed, in the benchmark at the end of this post, the RcppArmadillo version presented below is more than 200 times faster than the quicker of the two R implementations and about 700 times faster than the naive one!

Let us start with some R code. First, dmvnorm_deriv1 is a simple R implementation of the formula shown in the Matrix Cookbook (formulas 346 and 347, Nov 15, 2012 version). The second version, dmvnorm_deriv2, builds on Peter Rossi’s implementation of the multivariate normal density in his package bayesm, which uses a much faster algorithm and can be used to improve our implementation of the first derivative.
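
For orientation, here is the quantity that both R functions compute, written in the notation of the code (mu for the mean, Sigma for the covariance matrix). This is a restatement derived from what dmvnorm_deriv1 evaluates, not a verbatim copy of the Matrix Cookbook entries:

```latex
\begin{align}
p(x) &= \frac{1}{\sqrt{\det(2\pi\Sigma)}}
        \exp\!\left(-\tfrac{1}{2}\,(x-\mu)^{\top}\Sigma^{-1}(x-\mu)\right), \\
\frac{\partial p(x)}{\partial x} &= -\,p(x)\,\Sigma^{-1}(x-\mu).
\end{align}
```

The derivative is simply the density scaled by the vector \Sigma^{-1}(x-\mu), so any fast way of evaluating the density carries over directly to the derivative.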

library('RcppArmadillo',quietly=TRUE)
library('rbenchmark',quietly=TRUE)
library('mvtnorm',quietly=TRUE)

dmvnorm_deriv1 <- function(X, mu=rep(0,ncol(X)), sigma=diag(ncol(X))) {
    fn <- function(x) -1 * c((1/sqrt(det(2*pi*sigma))) * exp(-0.5*t(x-mu)%*%solve(sigma)%*%(x-mu))) * solve(sigma,(x-mu))
    out <- t(apply(X,1,fn))
    return(out)
}

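# mv normal density based on Peter Rossi's implementation in bayesm
dMvn <- function(X,mu,Sigma) {
    k <- ncol(X)
    # inverse of the upper-triangular Cholesky factor of Sigma
    rooti <- backsolve(chol(Sigma),diag(k))
    # squared Mahalanobis distance of each row of X from mu
    quads <- colSums((crossprod(rooti,(t(X)-mu)))^2)
    return(exp(-(k/2)*log(2*pi) + sum(log(diag(rooti))) - .5*quads))
}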

dmvnorm_deriv2 <- function(X, mean, sigma) {
    if (is.vector(X)) X <- matrix(X, ncol = length(X))
    if (missing(mean)) mean <- rep(0, length = ncol(X))
    if (missing(sigma)) sigma <- diag(ncol(X))
    n <- nrow(X)
    mvnorm <- dMvn(X, mu = mean, Sigma = sigma)
    deriv <- array(NA,c(n,ncol(X)))
    for (i in 1:n)
        deriv[i,] <- -mvnorm[i] * solve(sigma,(X[i,]-mean))
    return(deriv)
}
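
The speed of the bayesm-style density comes from a Cholesky factorization. As a sketch of the reasoning, assume R is the upper-triangular factor returned by chol(Sigma), so that rooti in the code is its inverse (the symbols R, z and k are notation introduced here for illustration):

```latex
\begin{align}
\Sigma &= R^{\top}R, \qquad z = R^{-\top}(x-\mu), \\
(x-\mu)^{\top}\Sigma^{-1}(x-\mu) &= z^{\top}z, \\
\log p(x) &= -\tfrac{k}{2}\log(2\pi)
             + \sum_{j=1}^{k}\log\big[(R^{-1})_{jj}\big]
             - \tfrac{1}{2}\,z^{\top}z .
\end{align}
```

The sum over the diagonal of R^{-1} equals -\tfrac{1}{2}\log\det\Sigma, so neither a determinant nor a full matrix inverse has to be computed for each observation.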


These implementations work but they are not very fast. dmvnorm_deriv1 is a one-to-one translation of the formula in pure R and more efficient algorithms exist. dmvnorm_deriv2 uses such an algorithm from the bayesm package and is significantly faster. As shown in this gallery post, a translation of this algorithm to RcppArmadillo leads to further performance improvements. But I assume that the real bottleneck is the loop for the calculation of the derivative. So let’s adopt the calculation of the multivariate normal from the gallery post and translate the loop to RcppArmadillo. After some cleaning, we get a nice RcppArmadillo implementation called dmvnorm_deriv_arma.

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

const double log2pi = std::log(2.0 * M_PI);

// [[Rcpp::export]]
arma::mat dmvnorm_deriv_arma(arma::mat x,
                             arma::rowvec mean,
                             arma::mat sigma) {

    int xdim = x.n_cols;
    arma::mat deriv;
    deriv.copy_size(x);
    // rooti is the transposed inverse of the upper-triangular Cholesky factor of sigma
    arma::mat rooti = arma::trans(arma::inv(arma::trimatu(arma::chol(sigma))));
    double rootisum = arma::sum(arma::log(rooti.diag()));
    // divide by 2.0 so that odd dimensions are not truncated by integer division
    double constants = -(xdim / 2.0) * log2pi;

    int n = x.n_rows;
    for (int i = 0; i < n; i++) {
        arma::vec x_centered = arma::trans(x.row(i) - mean);
        arma::vec z = rooti * x_centered;
        // exp(constants - 0.5 * arma::sum(z % z) + rootisum) is the multivariate
        // normal density; multiplying it by -solve(sigma, x_centered) gives the
        // first derivative
        deriv.row(i) = -1 * std::exp(constants - 0.5 * arma::sum(z % z) + rootisum) *
                       arma::trans(arma::solve(sigma, x_centered));
    }

    return deriv;
}
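
A quick way to gain confidence in a derivative like this is to compare it with central finite differences of the density. The sketch below is plain C++ without Armadillo and is restricted to the two-dimensional case; the names Vec2, Sigma2, solve2, dmvnorm2, dmvnorm2_grad and dmvnorm2_grad_fd are hypothetical helpers made up for this illustration and are not part of the code above.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// Hypothetical helpers for illustration only: 2-dimensional case, no Armadillo.
using Vec2 = std::array<double, 2>;

struct Sigma2 {  // symmetric 2x2 covariance matrix [[a, b], [b, d]]
    double a, b, d;
};

// Sigma^{-1} (x - mu), using the closed-form inverse of a 2x2 matrix.
Vec2 solve2(const Sigma2& s, const Vec2& x, const Vec2& mu) {
    const double det = s.a * s.d - s.b * s.b;
    assert(det > 0.0 && "sigma must be positive definite");
    const double c0 = x[0] - mu[0];
    const double c1 = x[1] - mu[1];
    return {( s.d * c0 - s.b * c1) / det,
            (-s.b * c0 + s.a * c1) / det};
}

// Bivariate normal density at x.
double dmvnorm2(const Vec2& x, const Vec2& mu, const Sigma2& s) {
    const double pi = std::acos(-1.0);
    const double det = s.a * s.d - s.b * s.b;
    const Vec2 w = solve2(s, x, mu);
    const double quad = (x[0] - mu[0]) * w[0] + (x[1] - mu[1]) * w[1];
    return std::exp(-0.5 * quad) / (2.0 * pi * std::sqrt(det));
}

// Analytic gradient: -p(x) * Sigma^{-1} (x - mu).
Vec2 dmvnorm2_grad(const Vec2& x, const Vec2& mu, const Sigma2& s) {
    const double p = dmvnorm2(x, mu, s);
    const Vec2 w = solve2(s, x, mu);
    return {-p * w[0], -p * w[1]};
}

// Central finite-difference approximation of the same gradient.
Vec2 dmvnorm2_grad_fd(const Vec2& x, const Vec2& mu, const Sigma2& s,
                      double h = 1e-5) {
    Vec2 g{};
    for (int j = 0; j < 2; ++j) {
        Vec2 up = x, down = x;
        up[j] += h;
        down[j] -= h;
        g[j] = (dmvnorm2(up, mu, s) - dmvnorm2(down, mu, s)) / (2.0 * h);
    }
    return g;
}
```

Calling dmvnorm2_grad and dmvnorm2_grad_fd at the same point should give values that agree to many digits; a larger discrepancy usually points to a sign error or a wrong normalizing constant. The same idea carries over to the R and RcppArmadillo versions by perturbing one column of X at a time.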


Finally, we can compare the different versions using simulated data.

set.seed(123456789)
s <- rWishart(1, 2, diag(2))[,,1]
m <- rnorm(2)
X <- rmvnorm(10000, m, s)

benchmark(dmvnorm_deriv_arma(X,m,s),
          dmvnorm_deriv1(X,mu=m,sigma=s),
          dmvnorm_deriv2(X,mean=m,sigma=s),
          order="relative", replications=50)[,1:4]

                                    test replications elapsed relative
1            dmvnorm_deriv_arma(X, m, s)           50   0.108      1.0
3 dmvnorm_deriv2(X, mean = m, sigma = s)           50  23.085    213.8
2   dmvnorm_deriv1(X, mu = m, sigma = s)           50  75.973    703.5


The RcppArmadillo implementation is several hundred times faster! Such stunning performance increases are possible when existing implementations rely on pure R. Of course, the R implementation can be improved as well.