Cross-Fitting Double Machine Learning estimator

June 28, 2017

(This article was first published on R – insightR, and kindly contributed to R-bloggers)

By Gabriel Vasconcelos


In a previous post I talked about inference after model selection, showing that a simple double selection procedure is enough to solve the problem. In this post I’m going to talk about a generalization of double selection to any Machine Learning (ML) method, described by Chernozhukov et al. (2016).

Suppose you are in the following framework:

\displaystyle y_i=d_i\theta +g_0(\boldsymbol{z}_i)+u_i
\displaystyle d_i=m_0(\boldsymbol{z}_i)+v_i

where \theta is the parameter of interest, \boldsymbol{z}_i is a set of control variables and u_i and v_i are error terms. The functions m_0 and g_0 are very generic and possibly non-linear.

How can we estimate \theta? The most naive way (and obviously wrong) is to simply estimate a regression using only d_i to explain y_i. Another way is to try something similar to the double selection, which Chernozhukov et al. (2016) refer to as Double Machine Learning (DML):

  • (1): Estimate d_i=\hat{m}_0(\boldsymbol{z}_i)+\hat{v}_i,
  • (2): Estimate y_i=\hat{g}_0(\boldsymbol{z}_i)+\hat{u}_i without using d_i,
  • (3): Estimate \hat{\theta}=(\sum_{i=1}^N \hat{v}_id_i)^{-1}\sum_{i=1}^N \hat{v}_i (y_i-\hat{g}_0(\boldsymbol{z}_i)).
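The three steps above can be sketched in a few lines of R. This is only an illustration under stated assumptions: `dml_naive` and `ols_learner` are hypothetical names, and plain OLS stands in for the ML method so the example runs with base R alone. With a linear learner the three steps reproduce the coefficient on d from regressing y on d and z (the Frisch-Waugh-Lovell result), which gives a simple sanity check.

```r
# Steps (1)-(3) with a pluggable learner. `learner(x, target)` is a
# hypothetical helper returning in-sample fitted values of `target` given `x`.
dml_naive <- function(y, d, z, learner) {
  mhat <- learner(z, d)                    # step (1): fit d on z
  vhat <- d - mhat                         # residuals v-hat
  ghat <- learner(z, y)                    # step (2): fit y on z, without d
  sum(vhat * (y - ghat)) / sum(vhat * d)   # step (3): theta-hat
}

# Stand-in learner (assumption): OLS fitted values instead of an ML method
ols_learner <- function(x, target) fitted(lm(target ~ x))

# Simulated linear example (assumption): true theta = 0.5
set.seed(1)
n <- 1000
z <- matrix(rnorm(n * 3), n, 3)
b <- c(1, 0.5, 0.25)
d <- as.vector(z %*% b) + rnorm(n)
y <- 0.5 * d + as.vector(z %*% b) + rnorm(n)

theta_dml <- dml_naive(y, d, z, ols_learner)
theta_ols <- unname(coef(lm(y ~ d + z))["d"])  # full regression, for comparison
```

Replacing `ols_learner` with a flexible method such as a Random Forest gives the naive DML discussed next.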

However, in this case the procedure above will leave a term in the asymptotic distribution of \hat{\theta} that will cause bias. This problem may be solved with sample splitting. Suppose we randomly split our N observations into two samples of size n=N/2. The first sample will be indexed by I and the auxiliary second sample will be indexed by I^c. We are going to estimate the first two steps above in the auxiliary sample I^c, and the third step will be done in sample I. Therefore:

\displaystyle \hat{\theta}=\left(\sum_{i\in I} \hat{v}_id_i \right)^{-1}\sum_{i\in I} \hat{v}_i (y_i-\hat{g}_0(\boldsymbol{z}_i))

The estimator above is unbiased. However, you are probably wondering about efficiency loss because \hat{\theta} was estimated using only half of the sample. To solve this problem we must now do the opposite: first we estimate steps 1 and 2 in the sample I and then we estimate \theta in the sample I^c. As a result, we will have \hat{\theta}(I^c,I) and \hat{\theta}(I,I^c). The cross-fitting estimator will be:

\displaystyle \hat{\theta}_{CF}=\frac{\hat{\theta}(I^c,I)+\hat{\theta}(I,I^c)}{2}

which is a simple average of the two estimates of \theta.
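A minimal sketch of the 2-fold cross-fitting estimator follows. The names `dml_split` and `ols_predict` are hypothetical, and OLS is again used as a stand-in learner (an assumption made to keep the example in base R); any method that can be fitted on one sample and used to predict on the other could be plugged in.

```r
# Hypothetical learner interface: fit on the training rows, predict for x_new.
# OLS is a stand-in (assumption) for any ML method.
ols_predict <- function(x_train, target_train, x_new) {
  beta <- coef(lm(target_train ~ x_train))
  as.vector(cbind(1, x_new) %*% beta)
}

# theta is estimated on `main`; ghat and mhat are learned on `aux` only
dml_split <- function(y, d, z, main, aux, learner) {
  ghat <- learner(z[aux, , drop = FALSE], y[aux], z[main, , drop = FALSE])
  mhat <- learner(z[aux, , drop = FALSE], d[aux], z[main, , drop = FALSE])
  vhat <- d[main] - mhat
  sum(vhat * (y[main] - ghat)) / sum(vhat * d[main])
}

# Simulated linear example (assumption): true theta = 0.5
set.seed(1)
N <- 2000
z <- matrix(rnorm(N * 3), N, 3)
b <- c(1, 0.5, 0.25)
d <- as.vector(z %*% b) + rnorm(N)
y <- 0.5 * d + as.vector(z %*% b) + rnorm(N)

I  <- sample(N, N / 2)          # main sample
IC <- setdiff(seq_len(N), I)    # auxiliary sample

theta_1 <- dml_split(y, d, z, main = I,  aux = IC, learner = ols_predict)
theta_2 <- dml_split(y, d, z, main = IC, aux = I,  learner = ols_predict)
theta_cf <- (theta_1 + theta_2) / 2   # cross-fitting estimate
```

Each observation is used once to learn the nuisance functions and once to estimate \theta, so no part of the sample is thrown away.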


The best way to illustrate this topic is using simulation. I am going to generate data from the following model:

\displaystyle y_i=\theta d_i + \cos^2(\boldsymbol{z}_i' b) + u_i

\displaystyle d_i = \sin(\boldsymbol{z}_i'b)+\cos(\boldsymbol{z}_i'b)+v_i

  • The number of observations and the number of simulations will be 500,
  • The number of variables in \boldsymbol{z}_i will be K=10, generated from a multivariate normal distribution,
  • \theta=0.5,
  • b_k=\frac{1}{k},~~~k=1,\dots,K,
  • u_i and v_i are normal with mean 0 and variance 1.

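library(mvtnorm); library(randomForest) # = rmvnorm and randomForest = #
set.seed(123) # = Seed for Replication = #
N=500 # = Number of observations = #
k=10 # = Number of variables in z = #
theta=0.5 # = Parameter of interest = #
b=1/(1:k) # = Coefficients b_k = 1/k = #

# = Generate covariance matrix of z = #
# = Assumption: the original line is missing here; the identity matrix is a placeholder, so results may differ from those reported = #
sigma=diag(k)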

The ML model we are going to use to estimate steps 1 and 2 is the Random Forest. The simulation will estimate the simple OLS using only d_i to explain y_i, the naive DML without sample splitting and the Cross-fitting DML. The 500 simulations may take a few minutes.

M=500 # = Number of Simulations = #

# = Matrix to store results = #
thetahat=matrix(NA,M,3)
colnames(thetahat)=c("OLS","Naive DML","Cross-fiting DML")

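for(i in 1:M){
  z=rmvnorm(N,sigma=sigma) # = Generate z = #
  g=as.vector(cos(z%*%b)^2) # = Generate the function g = #
  m=as.vector(sin(z%*%b)+cos(z%*%b)) # = Generate the function m = #
  d=m+rnorm(N) # = Generate d = #
  y=theta*d+g+rnorm(N) # = Generate y = #

  # = OLS estimate = #
  thetahat[i,1]=coef(lm(y~d))[2]

  # = Naive DML = #
  # = Compute ghat = #
  model=randomForest(z,y,maxnodes = 20)
  G=predict(model,z)
  # = Compute mhat = #
  modeld=randomForest(z,d,maxnodes = 20)
  # = compute vhat as the residuals of the second model = #
  V=d-predict(modeld,z)
  # = Compute DML theta = #
  thetahat[i,2]=sum(V*(y-G))/sum(V*d)

  # = Cross-fitting DML = #
  # = Split sample = #
  I=sort(sample(1:N,N/2))
  IC=setdiff(1:N,I)
  # = compute ghat on both samples (fit on one half, predict on the other) = #
  model1=randomForest(z[IC,],y[IC],maxnodes = 10)
  model2=randomForest(z[I,],y[I], maxnodes = 10)
  G1=predict(model1,z[I,])
  G2=predict(model2,z[IC,])

  # = Compute mhat and vhat on both samples = #
  modeld1=randomForest(z[IC,],d[IC],maxnodes = 10)
  modeld2=randomForest(z[I,],d[I],maxnodes = 10)
  V1=d[I]-predict(modeld1,z[I,])
  V2=d[IC]-predict(modeld2,z[IC,])

  # = Compute Cross-Fitting DML theta = #
  theta1=sum(V1*(y[I]-G1))/sum(V1*d[I])
  theta2=sum(V2*(y[IC]-G2))/sum(V2*d[IC])
  thetahat[i,3]=mean(c(theta1,theta2))
}
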
colMeans(thetahat) # = check the average theta for all models = #
##              OLS        Naive DML Cross-fiting DML
##        0.5465718        0.4155583        0.5065751
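# = plot distributions = #
dens=lapply(1:3,function(j) density(thetahat[,j])) # = one density per estimator = #
plot(dens[[1]],xlim=range(thetahat),ylim=c(0,max(sapply(dens,function(x) max(x$y)))),main="",xlab="Estimated theta")
lines(dens[[2]],col=2); lines(dens[[3]],col=4)
legend("topleft",legend=c("OLS","Naive DML","Cross-fiting DML"),col=c(1,2,4),lty=1,cex=0.7,seg.len = 0.7,bty="n")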

[Figure: distributions of the estimated \theta across the simulations for OLS, Naive DML and Cross-fitting DML]

As you can see, the only distribution centered close to the true value \theta=0.5 is the Cross-Fitting DML: on average, OLS overestimates \theta (0.547), the naive DML underestimates it (0.416), and the Cross-Fitting DML is close to the truth (0.507). However, the model used to estimate the functions m_0 and g_0 must be able to capture the relevant information in the data. In the linear case you may use the LASSO and achieve a result similar to the double selection. Finally, what we did here was a 2-fold Cross-Fitting, but you can also do a k-fold Cross-Fitting just like you do a k-fold cross-validation. The number of folds k does not matter asymptotically, but in small samples the results may change.
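As a rough sketch of the k-fold version: each fold takes a turn as the main sample while the remaining folds are used to learn the nuisance functions, and the fold-specific estimates are then averaged. The function `dml_kfold` and the OLS stand-in learner are hypothetical, and averaging the fold estimates is one possible way to combine them, assumed here for simplicity.

```r
# Hypothetical learner interface: fit on training rows, predict for x_new.
# OLS is a stand-in (assumption) for any ML method.
ols_predict <- function(x_train, target_train, x_new) {
  beta <- coef(lm(target_train ~ x_train))
  as.vector(cbind(1, x_new) %*% beta)
}

dml_kfold <- function(y, d, z, K, learner) {
  # Random fold labels 1..K of (nearly) equal size
  fold <- sample(rep(seq_len(K), length.out = length(y)))
  theta_k <- sapply(seq_len(K), function(k) {
    main <- which(fold == k)   # fold used to estimate theta
    aux  <- which(fold != k)   # remaining folds used to learn ghat and mhat
    ghat <- learner(z[aux, , drop = FALSE], y[aux], z[main, , drop = FALSE])
    mhat <- learner(z[aux, , drop = FALSE], d[aux], z[main, , drop = FALSE])
    vhat <- d[main] - mhat
    sum(vhat * (y[main] - ghat)) / sum(vhat * d[main])
  })
  list(theta = mean(theta_k), by_fold = theta_k)
}

# Simulated linear example (assumption): true theta = 0.5
set.seed(1)
N <- 2000
z <- matrix(rnorm(N * 3), N, 3)
b <- c(1, 0.5, 0.25)
d <- as.vector(z %*% b) + rnorm(N)
y <- 0.5 * d + as.vector(z %*% b) + rnorm(N)

res <- dml_kfold(y, d, z, K = 5, learner = ols_predict)
res$theta     # cross-fitting estimate
res$by_fold   # the K fold-specific estimates
```

Setting K = 2 gives the 2-fold estimator used in the simulation, up to the random split.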


Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., & Hansen, C. (2016). Double machine learning for treatment and causal parameters. arXiv preprint arXiv:1608.00060.
