We want to do matrix multiplication for three cases:

- dense times dense,
- sparse times dense, for sparse matrices of class `dgCMatrix`,
- sparse times dense, for sparse matrices of class `indMatrix`,

using R's Matrix package for sparse matrices in R and
RcppArmadillo for C++ linear algebra:
<span>// [[Rcpp::depends(RcppArmadillo)]]</span>
<span>#include <RcppArmadillo.h></span>
<span>using</span> <span>namespace</span> <span>Rcpp</span> <span>;</span>
<span>arma</span><span>::</span><span>mat</span> <span>matmult_sp</span><span>(</span><span>const</span> <span>arma</span><span>::</span><span>sp_mat</span> <span>X</span><span>,</span> <span>const</span> <span>arma</span><span>::</span><span>mat</span> <span>Y</span><span>){</span>
<span>arma</span><span>::</span><span>mat</span> <span>ret</span> <span>=</span> <span>X</span> <span>*</span> <span>Y</span><span>;</span>
<span>return</span> <span>ret</span><span>;</span>
<span>};</span>
<span>arma</span><span>::</span><span>mat</span> <span>matmult_dense</span><span>(</span><span>const</span> <span>arma</span><span>::</span><span>mat</span> <span>X</span><span>,</span> <span>const</span> <span>arma</span><span>::</span><span>mat</span> <span>Y</span><span>){</span>
<span>arma</span><span>::</span><span>mat</span> <span>ret</span> <span>=</span> <span>X</span> <span>*</span> <span>Y</span><span>;</span>
<span>return</span> <span>ret</span><span>;</span>
<span>};</span>
<span>arma</span><span>::</span><span>mat</span> <span>matmult_ind</span><span>(</span><span>const</span> <span>SEXP</span> <span>Xr</span><span>,</span> <span>const</span> <span>arma</span><span>::</span><span>mat</span> <span>Y</span><span>){</span>
<span>// pre-multiplication with index matrix is a permutation of Y's rows: </span>
<span>arma</span><span>::</span><span>uvec</span> <span>perm</span> <span>=</span> <span>as</span><span><</span><span>S4</span><span>></span><span>(</span><span>Xr</span><span>).</span><span>slot</span><span>(</span><span>"perm"</span><span>);</span>
<span>arma</span><span>::</span><span>mat</span> <span>ret</span> <span>=</span> <span>Y</span><span>.</span><span>rows</span><span>(</span><span>perm</span> <span>-</span> <span>1</span><span>);</span>
<span>return</span> <span>ret</span><span>;</span>
<span>};</span>
<span>//[[Rcpp::export]]</span>
<span>arma</span><span>::</span><span>mat</span> <span>matmult_cpp</span><span>(</span><span>SEXP</span> <span>Xr</span><span>,</span> <span>const</span> <span>arma</span><span>::</span><span>mat</span> <span>Y</span><span>)</span> <span>{</span>
<span>if</span> <span>(</span><span>Rf_isS4</span><span>(</span><span>Xr</span><span>))</span> <span>{</span>
<span>if</span><span>(</span><span>Rf_inherits</span><span>(</span><span>Xr</span><span>,</span> <span>"dgCMatrix"</span><span>))</span> <span>{</span>
<span>return</span> <span>matmult_sp</span><span>(</span><span>as</span><span><</span><span>arma</span><span>::</span><span>sp_mat</span><span>></span><span>(</span><span>Xr</span><span>),</span> <span>Y</span><span>)</span> <span>;</span>
<span>}</span> <span>;</span>
<span>if</span><span>(</span><span>Rf_inherits</span><span>(</span><span>Xr</span><span>,</span> <span>"indMatrix"</span><span>))</span> <span>{</span>
<span>return</span> <span>matmult_ind</span><span>(</span><span>Xr</span><span>,</span> <span>Y</span><span>)</span> <span>;</span>
<span>}</span> <span>;</span>
<span>stop</span><span>(</span><span>"unknown class of Xr"</span><span>)</span> <span>;</span>
<span>}</span> <span>else</span> <span>{</span>
<span>return</span> <span>matmult_dense</span><span>(</span><span>as</span><span><</span><span>arma</span><span>::</span><span>mat</span><span>></span><span>(</span><span>Xr</span><span>),</span> <span>Y</span><span>)</span> <span>;</span>
<span>}</span>
<span>}</span>
Set up test cases:
{{...