Cross-Fitting Double Machine Learning estimator

[This article was first published on R – insightR, and kindly contributed to R-bloggers].
By Gabriel Vasconcelos


In a previous post I discussed inference after model selection and showed that a simple double-selection procedure is enough to solve the problem. In this post I will discuss a generalization of double selection to any Machine Learning (ML) method, described by Chernozhukov et al. (2016).

Suppose you are in the following framework:

\displaystyle y_i=d_i\theta +g_0(\boldsymbol{z}_i)+u_i
\displaystyle d_i=m_0(\boldsymbol{z}_i)+v_i

where \theta is the parameter of interest, \boldsymbol{z}_i is a vector of control variables, and u_i and v_i are error terms with E[u_i|d_i,\boldsymbol{z}_i]=0 and E[v_i|\boldsymbol{z}_i]=0. The functions m_0 and g_0 are generic and possibly non-linear.

How can we estimate \theta? The most naive (and obviously wrong) approach is to simply regress y_i on d_i alone; because d_i and g_0(\boldsymbol{z}_i) both depend on \boldsymbol{z}_i, this suffers from omitted-variable bias. Another approach, similar to double selection, is what Chernozhukov et al. (2016) call Double Machine Learning (DML):

  • (1): Estimate d_i=\hat{m}_0(\boldsymbol{z}_i)+\hat{v}_i,
  • (2): Estimate y_i=\hat{g}_0(\boldsymbol{z}_i)+\hat{u}_i without using d_i,
  • (3): Estimate \hat{\theta}=(\sum_{i=1}^N \hat{v}_id_i)^{-1}\sum_{i=1}^N \hat{v}_i (y_i-\hat{g}_0(\boldsymbol{z}_i)).
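A minimal sketch of steps (1) to (3), assuming a single scalar control z_i and swapping the post's random forest for base R's smooth.spline so that no extra packages are needed. Both nuisance functions and \theta are estimated on the same observations, so this is the naive DML; the variable names are illustrative.

```r
# Naive DML (no sample splitting): nuisances and theta use the same data.
# Assumption: one scalar control z and a smoothing spline as the ML learner
# (the post itself uses randomForest with a 10-dimensional z).
set.seed(1)
n <- 500
theta <- 0.5
z <- rnorm(n)
d <- sin(z) + cos(z) + rnorm(n)        # d = m0(z) + v
y <- theta * d + cos(z)^2 + rnorm(n)   # y = d*theta + g0(z) + u

# Step (1): estimate m0 and keep the residuals v_hat
m_fit <- smooth.spline(z, d)
v_hat <- d - predict(m_fit, z)$y

# Step (2): regress y on z only (d is not used)
g_fit <- smooth.spline(z, y)
g_hat <- predict(g_fit, z)$y

# Step (3): theta_hat = (sum v_hat*d)^(-1) * sum v_hat*(y - g_hat)
theta_naive <- sum(v_hat * (y - g_hat)) / sum(v_hat * d)
theta_naive
```

With a single smooth control the overfitting bias is small; it becomes more visible with flexible learners and many controls, as in the simulation below.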

However, the procedure above leaves a term in the asymptotic distribution of \hat{\theta} that causes bias: the same observations are used to estimate both the nuisance functions and \theta, so overfitting errors in \hat{m}_0 and \hat{g}_0 contaminate \hat{\theta}. This problem can be solved with sample splitting. Suppose we randomly split our N observations into two samples of size n=N/2. The first (main) sample is indexed by I and the second (auxiliary) sample by I^c. We estimate the first two steps above in the auxiliary sample I^c and the third step in the main sample I. Therefore:

\displaystyle \hat{\theta}=\left(\sum_{i\in I} \hat{v}_id_i \right)^{-1}\sum_{i\in I} \hat{v}_i (y_i-\hat{g}_0(\boldsymbol{z}_i))

The estimator above does not suffer from this bias. However, you are probably wondering about the efficiency loss, because \hat{\theta} was estimated using only half of the sample. To solve this problem we now do the opposite: we estimate steps 1 and 2 in sample I and then estimate \theta in sample I^c. As a result, we have two estimates, \hat{\theta}(I^c,I) and \hat{\theta}(I,I^c), where the first argument is the sample used for the nuisance functions and the second is the sample used for \theta. The cross-fitting estimator is:

\displaystyle \hat{\theta}_{CF}=\frac{\hat{\theta}(I^c,I)+\hat{\theta}(I,I^c)}{2}

which is a simple average of the two estimates of \theta.
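A sketch of the 2-fold cross-fitting estimator under the same assumptions as before (scalar z_i, smooth.spline in place of the random forest). The helpers learn() and theta_split() are hypothetical names introduced for illustration; theta_split(aux, main) returns \hat{\theta}(aux, main), fitting the nuisance functions on aux and estimating \theta on main.

```r
# 2-fold cross-fitting DML sketch.
# Assumptions: scalar control z, smoothing-spline learner (swapped in for
# randomForest), hypothetical helper names learn() and theta_split().
set.seed(2)
n <- 500
theta <- 0.5
z <- rnorm(n)
d <- sin(z) + cos(z) + rnorm(n)
y <- theta * d + cos(z)^2 + rnorm(n)

# Hypothetical helper: fit on (x_tr, y_tr), predict at x_te
learn <- function(x_tr, y_tr, x_te) predict(smooth.spline(x_tr, y_tr), x_te)$y

# theta_hat(aux, main): nuisances fitted on `aux`, theta estimated on `main`
theta_split <- function(aux, main) {
  g_hat <- learn(z[aux], y[aux], z[main])
  v_hat <- d[main] - learn(z[aux], d[aux], z[main])
  sum(v_hat * (y[main] - g_hat)) / sum(v_hat * d[main])
}

I  <- sort(sample(n, n / 2))   # main sample I
IC <- setdiff(seq_len(n), I)   # auxiliary sample I^c
theta_1 <- theta_split(IC, I)  # theta_hat(I^c, I)
theta_2 <- theta_split(I, IC)  # theta_hat(I, I^c)
theta_cf <- (theta_1 + theta_2) / 2
c(theta_1 = theta_1, theta_2 = theta_2, theta_cf = theta_cf)
```

Every observation is used once for the nuisance fits and once for \theta, which is how cross-fitting recovers the efficiency lost by a single split.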


The easiest way to illustrate this is with a simulation. I am going to generate data from the following model:

\displaystyle y_i=\theta d_i + \cos^2(\boldsymbol{z}_i' b) + u_i

\displaystyle d_i = \sin(\boldsymbol{z}_i'b)+\cos(\boldsymbol{z}_i'b)+v_i

  • The number of observations and the number of simulations will be 500,
  • The number of variables in \boldsymbol{z}_i will be K=10, generated from a multivariate normal distribution,
  • \theta=0.5,
  • b_k=\frac{1}{k},~~~k=1,\dots,K,
  • u_i and v_i are normal with mean 0 and variance 1.

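library(mvtnorm) # = provides rmvnorm() = #
library(randomForest) # = provides randomForest() = #
set.seed(123) # = Seed for Replication = #
N=500 # = Number of observations = #
k=10 # = Number of variables in z = #
theta=0.5; b=1/(1:k) # = True theta and b_k = 1/k = #
# = Generate covariance matrix of z (assumption: the construction is not shown in the source; a random correlation matrix is used) = #
A=matrix(rnorm(k*k),k,k)
sigma=cov2cor(crossprod(A))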

The ML model we use to estimate steps 1 and 2 is the random forest. In each replication the simulation computes three estimators: simple OLS using only d_i to explain y_i, the naive DML without sample splitting, and the cross-fitting DML. The 500 simulations may take a few minutes.

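M=500 # = Number of Simulations = #

# = Matrix to store results = #
thetahat=matrix(NA,M,3)
colnames(thetahat)=c("OLS","Naive DML","Cross-fitting DML")

for(i in 1:M){
  z=rmvnorm(N,sigma=sigma) # = Generate z = #
  g=as.vector(cos(z%*%b)^2) # = Generate the function g = #
  m=as.vector(sin(z%*%b)+cos(z%*%b)) # = Generate the function m = #
  d=m+rnorm(N) # = Generate d = #
  y=theta*d+g+rnorm(N) # = Generate y = #

  # = OLS estimate = #
  thetahat[i,1]=coef(lm(y~d))[2]

  # = Naive DML = #
  # = Compute ghat (in-sample fitted values: no sample splitting) = #
  model=randomForest(z,y,maxnodes = 20)
  G=predict(model,z)
  # = Compute mhat = #
  modeld=randomForest(z,d,maxnodes = 20)
  # = compute vhat as the residuals of the second model = #
  V=d-predict(modeld,z)
  # = Compute DML theta = #
  thetahat[i,2]=sum(V*(y-G))/sum(V*d)

  # = Cross-fitting DML = #
  # = Split sample = #
  I=sort(sample(1:N,N/2))
  IC=setdiff(1:N,I)
  # = compute ghat on both samples = #
  model1=randomForest(z[IC,],y[IC],maxnodes = 10)
  model2=randomForest(z[I,],y[I], maxnodes = 10)
  G1=predict(model1,z[I,])
  G2=predict(model2,z[IC,])

  # = Compute mhat and vhat on both samples = #
  modeld1=randomForest(z[IC,],d[IC],maxnodes = 10)
  modeld2=randomForest(z[I,],d[I],maxnodes = 10)
  V1=d[I]-predict(modeld1,z[I,])
  V2=d[IC]-predict(modeld2,z[IC,])

  # = Compute Cross-Fitting DML theta = #
  theta1=sum(V1*(y[I]-G1))/sum(V1*d[I])
  theta2=sum(V2*(y[IC]-G2))/sum(V2*d[IC])
  thetahat[i,3]=mean(c(theta1,theta2))
}

colMeans(thetahat) # = check the average theta for all models = #

##              OLS        Naive DML Cross-fitting DML
##        0.5465718        0.4155583         0.5065751

# = plot distributions = #
dens=lapply(1:3,function(j) density(thetahat[,j]))
plot(dens[[1]],xlim=range(sapply(dens,function(x) range(x$x))),
     ylim=c(0,max(sapply(dens,function(x) max(x$y)))),main="",xlab=expression(hat(theta)))
lines(dens[[2]],col=2)
lines(dens[[3]],col=4)
abline(v=theta,lty=2) # = true theta = #
legend("topleft",legend=c("OLS","Naive DML","Cross-fitting DML"),col=c(1,2,4),lty=1,cex=0.7,seg.len = 0.7,bty="n")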

[Figure: densities of \hat{\theta} over the 500 simulations for OLS, Naive DML and Cross-fitting DML]

As you can see, only the Cross-Fitting DML distribution is centered at the true value \theta=0.5: OLS is biased upward and the naive DML is biased downward (average estimates of 0.547 and 0.416, against 0.507 for cross-fitting). However, the model used to estimate the functions m_0 and g_0 must be able to capture the relevant information in the data. In the linear case you may use the LASSO and achieve a result similar to double selection. Finally, what we did here was a 2-fold Cross-Fitting, but you can also do a k-fold Cross-Fitting, just as in k-fold cross-validation. The choice of k does not matter asymptotically, but the results may differ in small samples.
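A sketch of k-fold Cross-Fitting under the same assumptions as the earlier sketches (scalar z_i, smooth.spline instead of the random forest). Each fold in turn plays the role of I, the remaining folds play I^c, and the k fold estimates are averaged, mirroring the 2-fold average above. The function dml_crossfit() is a hypothetical name; solving a single pooled moment equation across folds is a common alternative to averaging.

```r
# k-fold cross-fitting DML sketch (generalizes the 2-fold case).
# Assumptions: scalar control z, smoothing-spline learner, hypothetical
# function name dml_crossfit().
dml_crossfit <- function(y, d, z, k = 5) {
  n <- length(y)
  folds <- sample(rep(seq_len(k), length.out = n))  # random balanced fold labels
  theta_k <- numeric(k)
  for (j in seq_len(k)) {
    main <- which(folds == j)   # fold used to estimate theta (plays I)
    aux  <- which(folds != j)   # remaining folds fit the nuisances (play I^c)
    g_hat <- predict(smooth.spline(z[aux], y[aux]), z[main])$y
    v_hat <- d[main] - predict(smooth.spline(z[aux], d[aux]), z[main])$y
    theta_k[j] <- sum(v_hat * (y[main] - g_hat)) / sum(v_hat * d[main])
  }
  mean(theta_k)                 # average of the k fold estimates
}

set.seed(3)
n <- 500
z <- rnorm(n)
d <- sin(z) + cos(z) + rnorm(n)
y <- 0.5 * d + cos(z)^2 + rnorm(n)
theta_cf5 <- dml_crossfit(y, d, z, k = 5)
theta_cf5
```

With k=2 this reduces to the estimator \hat{\theta}_{CF} above; larger k gives the nuisance learners more training data per fold at the cost of more model fits.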


Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., & Hansen, C. (2016). Double machine learning for treatment and causal parameters. arXiv preprint arXiv:1608.00060.
