The aim of this short tutorial is to demonstrate how
to implement causal machine learning estimators in practice.

We will focus on the average treatment effect (ATE), defined as ATE = E(Y^1) - E(Y^0), the difference between the means of the potential outcomes under treatment and under no treatment.

We will cover the naive plug-in estimator (i.e. plug-in g-computation with data-adaptive estimates), the augmented inverse probability weighted (AIPW) estimator (corresponding to the one-step and estimating-equations estimators), and targeted maximum likelihood estimation (TMLE).

We do this using the Super Learner in R (for the data-adaptive models).

We begin by loading the necessary libraries.

# Install the packages below (once, if needed) and then load them
#install.packages("SuperLearner")
#install.packages("xgboost")
#install.packages("tmle")
#install.packages("devtools")
#library(devtools)
#install_github("ehkennedy/npcausal")
library(npcausal)
library(boot)
## Warning: package 'boot' was built under R version 3.6.2
library(MASS) 
## Warning: package 'MASS' was built under R version 3.6.2
library(SuperLearner)
## Loading required package: nnls
## Super Learner
## Version: 2.0-26
## Package created on 2019-10-27
## 
## Attaching package: 'SuperLearner'
## The following object is masked from 'package:npcausal':
## 
##     SL.ranger
library(survey)
## Warning: package 'survey' was built under R version 3.6.2
## Loading required package: grid
## Loading required package: Matrix
## Loading required package: survival
## Warning: package 'survival' was built under R version 3.6.2
## 
## Attaching package: 'survival'
## The following object is masked from 'package:boot':
## 
##     aml
## 
## Attaching package: 'survey'
## The following object is masked from 'package:graphics':
## 
##     dotchart
library(npcausal)
library(tmle)
## Warning: package 'tmle' was built under R version 3.6.2
## Loading required package: glmnet
## Warning: package 'glmnet' was built under R version 3.6.2
## Loaded glmnet 4.0-2
## Welcome to the tmle package, version 1.5.0-1
## 
## Major changes since v1.3.x. Use tmleNews() to see details on changes and bug fixes

We will use simulated data, with a binary outcome Y, a binary treatment A, a sample size of n = 1000, and four baseline covariates W (dim W = 4) that we adjust for to control for confounding. The following code generates the data:

set.seed(129)
n=1000
w1 <- rbinom(n, size=1, prob=0.5)
w2 <- rbinom(n, size=1, prob=0.65)
w3 <- round(runif(n, min=0, max=4), digits=3)
w4 <- round(runif(n, min=0, max=5), digits=3)
A <- rbinom(n, size=1,
           prob= plogis(-0.4 + 0.2*w2 + 0.15*w3 + 0.2*w4 + 0.15*w2*w4))
Y <- rbinom(n, size=1,
           prob= plogis(-1 + A -0.1*w1 + 0.3*w2 + 0.25*w3 + 0.2*w4 + 0.15*w2*w4))
Y.1 <-  plogis( -0.1*w1 + 0.3*w2 + 0.25*w3 + 0.2*w4 + 0.15*w2*w4)
Y.0 <-  plogis(-1  -0.1*w1 + 0.3*w2 + 0.25*w3 + 0.2*w4 + 0.15*w2*w4)
trueATE<-mean(Y.1)-mean(Y.0)
trueATE
## [1] 0.1959683
#Create data frame with baseline covariates
W<-data.frame(cbind(w1,w2,w3,w4))
data<-data.frame(cbind(W,A,Y))
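
As a quick extra check (an addition to the original script), we can compare the unadjusted difference in means, which is confounded because A depends on W, with the true ATE computed above:

# Naive, unadjusted comparison: biased for the ATE because treatment depends on W
mean(Y[A==1]) - mean(Y[A==0])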

Super Learner

First, check which learners have been integrated into the SuperLearner package. We can use any of these when we run the SuperLearner:

library(SuperLearner)
listWrappers(what = "SL")
## All prediction algorithm wrappers in SuperLearner:
##  [1] "SL.bartMachine"      "SL.bayesglm"         "SL.biglasso"        
##  [4] "SL.caret"            "SL.caret.rpart"      "SL.cforest"         
##  [7] "SL.earth"            "SL.extraTrees"       "SL.gam"             
## [10] "SL.gbm"              "SL.glm"              "SL.glm.interaction" 
## [13] "SL.glmnet"           "SL.ipredbagg"        "SL.kernelKnn"       
## [16] "SL.knn"              "SL.ksvm"             "SL.lda"             
## [19] "SL.leekasso"         "SL.lm"               "SL.loess"           
## [22] "SL.logreg"           "SL.mean"             "SL.nnet"            
## [25] "SL.nnls"             "SL.polymars"         "SL.qda"             
## [28] "SL.randomForest"     "SL.ranger"           "SL.ridge"           
## [31] "SL.rpart"            "SL.rpartPrune"       "SL.speedglm"        
## [34] "SL.speedlm"          "SL.step"             "SL.step.forward"    
## [37] "SL.step.interaction" "SL.stepAIC"          "SL.svm"             
## [40] "SL.template"         "SL.xgboost"

Here we will use the following learners (as specified in the lecture)

SL.library<- c("SL.glm", "SL.glm.interaction", "SL.xgboost", "SL.glmnet", "SL.ranger")

Ideally, each of these algorithms should be included with multiple hyperparameter settings, which can be tuned using cross-validation (CV).

In the interest of time, here we only use the defaults. Make sure you check what these defaults are for each learner, by typing its name and inspecting the options pre-programmed in the SL wrapper; for example, for random forests using the ranger implementation:

SL.ranger
## function (Y, X, newX, family, obsWeights, num.trees = 500, mtry = floor(sqrt(ncol(X))), 
##     write.forest = TRUE, probability = family$family == "binomial", 
##     min.node.size = ifelse(family$family == "gaussian", 5, 1), 
##     replace = TRUE, sample.fraction = ifelse(replace, 1, 0.632), 
##     num.threads = 1, verbose = T, ...) 
## {
##     .SL.require("ranger")
##     if (family$family == "binomial") {
##         Y = as.factor(Y)
##     }
##     if (is.matrix(X)) {
##         X = data.frame(X)
##     }
##     fit <- ranger::ranger(`_Y` ~ ., data = cbind(`_Y` = Y, X), 
##         num.trees = num.trees, mtry = mtry, min.node.size = min.node.size, 
##         replace = replace, sample.fraction = sample.fraction, 
##         case.weights = obsWeights, write.forest = write.forest, 
##         probability = probability, num.threads = num.threads, 
##         verbose = verbose)
##     pred <- predict(fit, data = newX)$predictions
##     if (family$family == "binomial") {
##         pred = pred[, "1"]
##     }
##     fit <- list(object = fit, verbose = verbose)
##     class(fit) <- c("SL.ranger")
##     out <- list(pred = pred, fit = fit)
##     return(out)
## }
## <bytecode: 0x7f8742b1c170>
## <environment: namespace:SuperLearner>

We see that the number of trees is 500 and that the number of variables considered at each split (mtry) is the square root of the number of predictors, rounded down to the nearest integer.
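
If you want to go beyond the defaults, the SuperLearner package includes the helper create.Learner(), which generates wrapper variants over a grid of hyperparameters. The following is a purely illustrative sketch (the grid values are arbitrary choices, and these variants are not used in the rest of this tutorial):

# Illustrative only: create SL.ranger variants over a small num.trees/mtry grid
ranger.grid <- create.Learner("SL.ranger",
                              tune = list(num.trees = c(500, 1000),
                                          mtry = c(2, 3)))
ranger.grid$names  # the generated wrappers, which could be appended to SL.library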

SL for the outcome regression (naive plug-in g-computation)

SL.outcome<- SuperLearner(Y=data$Y, X=subset(data, select=-Y),
                                     SL.library=SL.library, family="binomial")
## Loading required namespace: xgboost
## Loading required namespace: ranger

You can look at the Super Learner object to see how the algorithms are weighted

SL.outcome
## 
## Call:  
## SuperLearner(Y = data$Y, X = subset(data, select = -Y), family = "binomial",  
##     SL.library = SL.library) 
## 
## 
##                             Risk       Coef
## SL.glm_All             0.1713686 0.90571079
## SL.glm.interaction_All 0.1735955 0.09428921
## SL.xgboost_All         0.2070323 0.00000000
## SL.glmnet_All          0.1714559 0.00000000
## SL.ranger_All          0.1817116 0.00000000

Now, based on the trained SL, we get the predicted outcome under the exposure level actually received, as well as the predictions of the two potential outcomes, for everyone

SL.outcome.obs<- predict(SL.outcome, newdata=subset(data, select=-Y))$pred
# predict the PO Y^1
SL.outcome.exp<- predict(SL.outcome, newdata=data.frame(cbind(W,A=rep(1,length(A)))))$pred
# predict the PO Y^0
SL.outcome.unexp<- predict(SL.outcome, newdata=data.frame(cbind(W,A=rep(0,length(A)))))$pred

SL g-computation

We can now use these two sets of predictions to get the plug-in g-computation estimate

SL.plugin.gcomp<-mean(SL.outcome.exp-SL.outcome.unexp)
SL.plugin.gcomp
## [1] 0.1999569

Warning: there is no valid way of doing inference for this estimator; in particular, the bootstrap is not valid when data-adaptive (ML) fits are used.

We collate the SL fits, because we’re going to use them later

Q=cbind(SL.outcome.obs, SL.outcome.unexp,SL.outcome.exp)
colnames(Q)<-c("QAW","Q0W","Q1W")

plug-in AIPW

Now we will use the outcome predictions and the propensity score predictions to construct the AIPW estimator with SL plug-ins.

First, the SL for the propensity score

SL.g<- SuperLearner(Y=data$A, X=subset(data, select=-c(A,Y)),
                    SL.library=SL.library, family="binomial")

You can look at the Super Learner object to see how the algorithms are weighted

SL.g
## 
## Call:  
## SuperLearner(Y = data$A, X = subset(data, select = -c(A, Y)), family = "binomial",  
##     SL.library = SL.library) 
## 
## 
##                             Risk       Coef
## SL.glm_All             0.1997010 0.42764584
## SL.glm.interaction_All 0.2015690 0.00000000
## SL.xgboost_All         0.2377393 0.00000000
## SL.glmnet_All          0.1996967 0.50222877
## SL.ranger_All          0.2146869 0.07012538

We see that here three of the learners (SL.glm, SL.glmnet and SL.ranger) receive non-zero coefficients in the SL.

Now, get the probability of getting the exposure

g1W <- SL.g$SL.predict
summary(g1W)
##        V1        
##  Min.   :0.3966  
##  1st Qu.:0.6437  
##  Median :0.7137  
##  Mean   :0.7050  
##  3rd Qu.:0.7828  
##  Max.   :0.8926
# Look at the histogram of PS
hist(g1W)

# Look at the histogram of the weights. 
hist(1/g1W)

For any real analysis, you must satisfy yourself that the positivity assumption holds, so that the weights are not “too” large.
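
As an illustration (an addition to the tutorial), one could inspect the overlap of the estimated propensity score by treatment arm, and, if needed, truncate extreme scores; the 0.025/0.975 bounds below are a common but arbitrary choice:

# Overlap of the estimated propensity score in the exposed and unexposed
summary(g1W[data$A==1])
summary(g1W[data$A==0])
# A common (but arbitrary) truncation, bounding the scores away from 0 and 1
# g1W <- pmin(pmax(g1W, 0.025), 0.975)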

Now the probability of being unexposed

g0W<- 1- g1W

We can now use these quantities to estimate the means of the potential outcomes, and thus the ATE, based on the IF shown in the lecture.
The IFs for the AIPW estimators of the means of Y^1 and Y^0 can be written as

IF.1<-((data$A/g1W)*(data$Y-Q[,"Q1W"])+Q[,"Q1W"])
IF.0<-(((1-data$A)/g0W)*(data$Y-Q[,"Q0W"])+Q[,"Q0W"])
#The IF of the ATE is then
IF<-IF.1-IF.0

We saw that the estimating-equation estimator of the ATE is the mean of the IF

aipw.1<-mean(IF.1);aipw.0<-mean(IF.0)
aipw.manual<-aipw.1-aipw.0

We now know that this estimator is asymptotically Normally distributed, and that its variance can be estimated by var(IF)/n

ci.lb<-mean(IF)-qnorm(.975)*sd(IF)/sqrt(length(IF))
ci.ub<-mean(IF)+qnorm(.975)*sd(IF)/sqrt(length(IF))
res.manual.aipw<-c(aipw.manual, ci.lb, ci.ub)
res.manual.aipw
## [1] 0.1992587 0.1414578 0.2570596

AIPW using the package npcausal

Now that you have seen how the concept works, you can use the npcausal package, which has pre-programmed this and other estimands.

For now, we specify no sample splitting (nsplits=1)

library(npcausal)
aipw<- ate(y=Y, a=A, x=W, nsplits=1, sl.lib=c("SL.glm", "SL.glm.interaction", "SL.glmnet", "SL.ranger"))
## Loading required package: earth
## Warning in library(package, lib.loc = lib.loc, character.only = TRUE,
## logical.return = TRUE, : there is no package called 'earth'
## Loading required package: gam
## Loading required package: splines
## Loading required package: foreach
## Warning: package 'foreach' was built under R version 3.6.2
## Loaded gam 1.16.1
## Loading required package: ranger
## Loading required package: rpart
##      parameter       est         se     ci.ll     ci.ul pval
## 1      E{Y(0)} 0.5941452 0.02492451 0.5452932 0.6429973    0
## 2      E{Y(1)} 0.7933962 0.01564994 0.7627223 0.8240701    0
## 3 E{Y(1)-Y(0)} 0.1992510 0.02888682 0.1426328 0.2558691    0
aipw$res
##      parameter       est         se     ci.ll     ci.ul pval
## 1      E{Y(0)} 0.5941452 0.02492451 0.5452932 0.6429973    0
## 2      E{Y(1)} 0.7933962 0.01564994 0.7627223 0.8240701    0
## 3 E{Y(1)-Y(0)} 0.1992510 0.02888682 0.1426328 0.2558691    0

TMLE

We now move on to the TMLE for the ATE. Using the following code, you can implement a TMLE by hand, based on the clever covariate approach you saw in the first session

# First E(Y1)
# Constructing the clever covariate
H<-as.numeric(data$A/g1W)

We now fit a parametric model, with the clever covariate as the only explanatory variable, and using the initial outcome predictions (on the logit scale) as an offset

model<-glm(data$Y~-1+H+offset(qlogis(Q[,"QAW"])),family=binomial)
summary(model)
## 
## Call:
## glm(formula = data$Y ~ -1 + H + offset(qlogis(Q[, "QAW"])), family = binomial)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.5238  -0.8974   0.5221   0.7606   1.6624  
## 
## Coefficients:
##     Estimate Std. Error z value Pr(>|z|)
## H -0.0008892  0.0667585  -0.013    0.989
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 1016.8  on 1000  degrees of freedom
## Residual deviance: 1016.8  on  999  degrees of freedom
## AIC: 1018.8
## 
## Number of Fisher Scoring iterations: 4

We update the initial predictions using the coefficient of the clever covariate

Q1W.1<-plogis(qlogis(Q[,"Q1W"])+coef(model)[1]/g1W)

And use this to get the TMLE estimate of the mean of Y^1

# Estimating E(Y1)
mean(Q1W.1)
## [1] 0.7933071

We now repeat for Y^0

# E(Y0)
# Constructing the clever covariate
H<-as.numeric((1-data$A)/g0W)
# Fitting a parametric extension model
model<-glm(data$Y~-1+H+offset(qlogis(Q[,"QAW"])),family=binomial)
summary(model)
## 
## Call:
## glm(formula = data$Y ~ -1 + H + offset(qlogis(Q[, "QAW"])), family = binomial)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.5242  -0.8980   0.5218   0.7602   1.6617  
## 
## Coefficients:
##    Estimate Std. Error z value Pr(>|z|)
## H 0.0007479  0.0387300   0.019    0.985
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 1016.8  on 1000  degrees of freedom
## Residual deviance: 1016.8  on  999  degrees of freedom
## AIC: 1018.8
## 
## Number of Fisher Scoring iterations: 3
# Updating the predictions
Q0W.1<-plogis(qlogis(Q[,"Q0W"])+coef(model)[1]/g0W)
# Estimating E(Y0)
mean(Q0W.1)
## [1] 0.5941483

And put the two together to get the TMLE for the ATE

# ATE = E(Y1)-E(Y0)
TMLE.1 <- mean(Q1W.1) - mean(Q0W.1)
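
The point estimate alone does not give us inference. Here is a minimal sketch (an addition to the original code) of how a confidence interval could be obtained from the estimated efficient influence curve, using the updated predictions from above:

# Efficient influence curve of the ATE, evaluated at the targeted (updated) fit
D1 <- (data$A/g1W)*(data$Y - Q1W.1) + Q1W.1 - mean(Q1W.1)
D0 <- ((1-data$A)/g0W)*(data$Y - Q0W.1) + Q0W.1 - mean(Q0W.1)
IC <- as.numeric(D1 - D0)
TMLE.se <- sd(IC)/sqrt(n)
c(TMLE.1 - qnorm(.975)*TMLE.se, TMLE.1 + qnorm(.975)*TMLE.se)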

You can do all of this automatically using the tmle package, which also has other estimands coded. Other TMLE packages exist for other common settings, such as mediation, IV regression, or longitudinal data

TMLE using the R package

library(tmle)
TMLE<- tmle(Y=data$Y,A=data$A,W=subset(data, select=-c(A,Y)), family="binomial", Q.SL.library=SL.library, g.SL.library=SL.library)

TMLE$estimates$ATE
## $psi
## [1] 0.196299
## 
## $var.psi
## [1] 0.0009300406
## 
## $CI
## [1] 0.1365258 0.2560723
## 
## $pvalue
## [1] 1.22052e-10
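
Since the outcome here is binary, the fitted tmle object also reports other marginal estimands; in the package version loaded above they should be accessible as, for example:

TMLE$estimates$RR
TMLE$estimates$OR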

Cross-fitting

It turns out that, to remove further bias while avoiding extra assumptions, we should use sample splitting; even better, we should use cross-fitting. This can be done relatively easily with the npcausal package

aipw.2<- ate(y=Y, a=A, x=W, nsplits=10, sl.lib=c("SL.glm", "SL.glm.interaction", "SL.glmnet", "SL.ranger"))
## Loading required package: earth
## Warning in library(package, lib.loc = lib.loc, character.only = TRUE,
## logical.return = TRUE, : there is no package called 'earth'
##      parameter       est         se     ci.ll     ci.ul pval
## 1      E{Y(0)} 0.5954396 0.02830390 0.5399639 0.6509152    0
## 2      E{Y(1)} 0.7933503 0.01574690 0.7624863 0.8242142    0
## 3 E{Y(1)-Y(0)} 0.1979107 0.03190509 0.1353767 0.2604447    0
aipw.2$res
##      parameter       est         se     ci.ll     ci.ul pval
## 1      E{Y(0)} 0.5954396 0.02830390 0.5399639 0.6509152    0
## 2      E{Y(1)} 0.7933503 0.01574690 0.7624863 0.8242142    0
## 3 E{Y(1)-Y(0)} 0.1979107 0.03190509 0.1353767 0.2604447    0
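
To wrap up, here is a small sketch (an addition to the tutorial) collecting the point estimates computed above next to the true ATE; it assumes the npcausal results are stored as data frames with an est column, as printed above:

round(c(true = trueATE,
        plugin.gcomp = SL.plugin.gcomp,
        aipw.manual = aipw.manual,
        aipw.npcausal = aipw$res$est[3],
        tmle.manual = TMLE.1,
        tmle.pkg = TMLE$estimates$ATE$psi,
        aipw.crossfit = aipw.2$res$est[3]), 3)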

You should also check tmle3, the newest implementation of TMLE, where the default option is to fit a CV-TMLE: https://tlverse.org/tlverse-handbook/tmle3.html

Remember, when doing your own analyses, to tune your learners. To learn how to do this using the SL, visit https://cran.r-project.org/web/packages/SuperLearner/vignettes/Guide-to-SuperLearner.html
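
As a related pointer (an addition, not run here), CV.SuperLearner() from the same package gives an honest, cross-validated comparison of the candidate learners and of the ensemble itself, which is useful when deciding which learners and tuning values to keep:

# Illustrative only; this refits the Super Learner V times and can be slow
cv.sl <- CV.SuperLearner(Y=data$Y, X=subset(data, select=-Y),
                         family="binomial", V=5, SL.library=SL.library)
summary(cv.sl)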

Further reading

Causal machine learning

For more general reading on debiased machine learning and TMLE, see

  • van der Laan, M. J. and Rose, S. (2011). Targeted Learning. Springer Series in Statistics. Springer New York, New York, NY

  • Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W., and Robins, J. (2018). Double/debiased machine learning for treatment and structural parameters. The Econometrics Journal, 21(1):C1–C68.

Influence functions

  • Fisher, A., & Kennedy, E. H. (2020). Visually communicating and teaching intuition for influence functions. The American Statistician, 1-11.

  • Levy, J. (2019). Tutorial: Deriving The Efficient Influence Curve for Large Models. https://arxiv.org/abs/1903.01706

  • Ichimura, H. and Newey, W. K. (2015). The influence function of semiparametric estimators. arXiv preprint arXiv:1508.01378

R Software packages