Specifying Accelerated Failure Time Models in Stan

March 8, 2019

This post is an add-on to my previous post about augmented Gibbs sampling for censored survival times. If you’re not a complete maniac like me, then you probably don’t want to code your own sampler from scratch like I did in that previous post. Luckily, you don’t have to, because you can easily specify that same model in Stan.

Let’s start by simulating some randomly censored data from a Weibull model. In this case, we include just a binary treatment indicator and are interested in characterizing survival in the two groups.

library(rstan)    # sampling(), extract()
library(survival) # Surv(), survfit()

set.seed(1)

n <- 1000

# simulate covariates (just a binary treatment indicator)
A <- rbinom(n, 1, .5)
X <- model.matrix(~ A)

# true parameters
true_beta <- (1/2)*matrix(c(-1/3, 2), ncol=1)
true_mu <- X %*% true_beta

true_sigma <- 1

true_alpha <- 1/true_sigma
true_lambda <- exp(-1*true_mu*true_alpha)

# simulate censoring and survival times
survt = rweibull(n, shape=true_alpha, scale = true_lambda)
cent = rweibull(n, shape=true_alpha, scale = true_lambda)

## observed data:
#censoring indicator
delta <- cent < survt
survt[delta==1] <- cent[delta==1] # censor survival time.

# count number of missing/censored survival times
n_miss <- sum(delta)

d_list <- list(N_m = n_miss, N_o = n - n_miss, P=2, # number of betas
               # data for censored subjects
               y_m=survt[delta==1], X_m=X[delta==1,],
               # data for uncensored subjects
               y_o=survt[delta==0], X_o=X[delta==0,])

The list d_list is what we’ll eventually feed to Stan. Below is the Stan model for Weibull-distributed survival times. Note that in the transformed parameters block we specify the canonical accelerated failure time (AFT) parameterization, modeling the scale as a function of the shape parameter, $\alpha$, and covariates.

In the model block, we specify the likelihood as the Weibull density for uncensored subjects, and then augment the likelihood with evaluations of the Weibull survival function for censored subjects (weibull_lccdf, the log complementary CDF).
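
For reference, the pieces of that likelihood can be written out as follows. This is a sketch that assumes Stan's documented (shape, scale) Weibull parameterization, with $\sigma_i$ standing for the scale that the Stan code calls lambda. The index sets $\mathcal{O}$ and $\mathcal{C}$ for uncensored and censored subjects are notation introduced here for illustration, not taken from the post.

$$
f(y \mid \alpha, \sigma) = \frac{\alpha}{\sigma}\left(\frac{y}{\sigma}\right)^{\alpha-1}\exp\left\{-\left(\frac{y}{\sigma}\right)^{\alpha}\right\},
\qquad
S(y \mid \alpha, \sigma) = \exp\left\{-\left(\frac{y}{\sigma}\right)^{\alpha}\right\}
$$

$$
\sigma_i = \exp\left\{(x_i^\top \beta)\,\alpha\right\}
$$

$$
\log L(\beta, \alpha) = \sum_{i \in \mathcal{O}} \log f(y_i \mid \alpha, \sigma_i) + \sum_{i \in \mathcal{C}} \log S(y_i \mid \alpha, \sigma_i)
$$

The first sum corresponds to weibull_lpdf evaluated at y_o, and the second to weibull_lccdf evaluated at y_m.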

The generated quantities block transforms the parameters to get posterior draws of the hazard ratio (as specified in my previous post) as well as posterior draws of the survival function.
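
It may help to spell out what this transformation measures. The sketch below assumes Stan's (shape, scale) Weibull parameterization and writes $\sigma_{trt}$ and $\sigma_{pbo}$ for lambda_trt and lambda_pbo. Treat it as a reading aid under that assumption rather than as part of the original derivation.

$$
h(t \mid \alpha, \sigma) = \frac{f(t \mid \alpha, \sigma)}{S(t \mid \alpha, \sigma)} = \alpha\,\sigma^{-\alpha}\,t^{\alpha-1}
$$

$$
\frac{\sigma_{trt}}{\sigma_{pbo}} = \exp(\beta_2\,\alpha),
\qquad
\frac{h_{trt}(t)}{h_{pbo}(t)} = \left(\frac{\sigma_{trt}}{\sigma_{pbo}}\right)^{-\alpha} = \exp(-\beta_2\,\alpha^2)
$$

Under this reading, the quantity stored as hazard_ratio is the ratio of the two Weibull scales. The two expressions are reciprocals of each other when $\alpha = 1$, which is the shape used to simulate the data here.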

data {
  int P; // number of beta parameters

  // data for censored subjects
  int N_m;
  matrix[N_m,P] X_m;
  vector[N_m] y_m;

  // data for uncensored subjects
  int N_o;
  matrix[N_o,P] X_o;
  vector[N_o] y_o;
}

parameters {
  vector[P] beta;
  real<lower=0> alpha; // Weibull shape (must be positive)
}

transformed parameters{
  // model Weibull scale as function of covariates
  // (Stan's weibull takes shape first, then scale)
  vector[N_m] lambda_m;
  vector[N_o] lambda_o;

  // standard weibull AFT re-parameterization
  lambda_m = exp((X_m*beta)*alpha);
  lambda_o = exp((X_o*beta)*alpha);
}

model {
  beta ~ normal(0, 100);
  alpha ~ exponential(1);

  // evaluate likelihood for censored and uncensored subjects
  target += weibull_lpdf(y_o | alpha, lambda_o);
  target += weibull_lccdf(y_m | alpha, lambda_m);
}

// generate posterior quantities of interest
generated quantities{
  vector[1000] post_pred_trt;
  vector[1000] post_pred_pbo;
  real lambda_trt;
  real lambda_pbo;
  real hazard_ratio;

  // generate hazard ratio
  lambda_trt = exp((beta[1] + beta[2])*alpha ) ;
  lambda_pbo = exp((beta[1])*alpha ) ;

  hazard_ratio = exp(beta[2]*alpha ) ;

  // generate survival times (for plotting survival curves)
  for(i in 1:1000){
    post_pred_trt[i] = weibull_rng(alpha, lambda_trt);
    post_pred_pbo[i] = weibull_rng(alpha, lambda_pbo);
  }
}
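
As a sanity check on the model block above, the same censored log-likelihood can be written in a few lines of base R. This is only a sketch: the function name weibull_aft_loglik and its arguments are hypothetical and not part of the original post, and it assumes the same scale parameterization as the Stan code.

```r
# Hypothetical helper (not from the original post): censored Weibull
# log-likelihood with scale = exp((X %*% beta) * alpha), as in the Stan model.
weibull_aft_loglik <- function(beta, alpha, y, X, censored) {
  censored <- as.logical(censored)

  # scale for each subject, mirroring lambda in the transformed parameters block
  scale <- exp(drop(X %*% beta) * alpha)

  # uncensored subjects contribute the log density (weibull_lpdf in Stan)
  ll_obs <- dweibull(y[!censored], shape = alpha,
                     scale = scale[!censored], log = TRUE)

  # censored subjects contribute the log survival function (weibull_lccdf)
  ll_cens <- pweibull(y[censored], shape = alpha, scale = scale[censored],
                      lower.tail = FALSE, log.p = TRUE)

  sum(ll_obs) + sum(ll_cens)
}

# toy check: intercept-only design, beta = 0 and alpha = 1 give unit scale,
# so the log density at y = 1 is -1 and the log survival at y = 2 is -2
X_toy    <- matrix(1, nrow = 2, ncol = 1)
y_toy    <- c(1, 2)
cens_toy <- c(FALSE, TRUE)
ll_toy   <- weibull_aft_loglik(beta = 0, alpha = 1, y = y_toy,
                               X = X_toy, censored = cens_toy)
ll_toy   # -3
```

With the simulated data from earlier, a call such as weibull_aft_loglik(-true_beta, true_alpha, survt, X, delta) would evaluate the log-likelihood at the data-generating values. The sign flip reflects that the simulation sets the scale to exp(-1*true_mu*true_alpha), while the Stan model uses exp((X*beta)*alpha).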


The Stan model specified above is compiled (for example with rstan's stan_model()) and stored in an object called weibull_mod, which is passed to sampling() below. The code samples from the posterior and returns posterior draws of the hazard ratio and predicted survival times.

weibull_fit <- sampling(weibull_mod,
                        data = d_list,
                        chains = 1, iter=20000, warmup=19000, save_warmup=FALSE,
                        pars= c('hazard_ratio','post_pred_trt','post_pred_pbo'))
##
## SAMPLING FOR MODEL '80acc0f9293b946800a710dd7f5e211c' NOW (CHAIN 1).
## Chain 1:
## Chain 1: Gradient evaluation took 0.000176 seconds
## Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 1.76 seconds.
## Chain 1:
## Chain 1:
## Chain 1: Iteration:     1 / 20000 [  0%]  (Warmup)
## Chain 1: Iteration:  2000 / 20000 [ 10%]  (Warmup)
## Chain 1: Iteration:  4000 / 20000 [ 20%]  (Warmup)
## Chain 1: Iteration:  6000 / 20000 [ 30%]  (Warmup)
## Chain 1: Iteration:  8000 / 20000 [ 40%]  (Warmup)
## Chain 1: Iteration: 10000 / 20000 [ 50%]  (Warmup)
## Chain 1: Iteration: 12000 / 20000 [ 60%]  (Warmup)
## Chain 1: Iteration: 14000 / 20000 [ 70%]  (Warmup)
## Chain 1: Iteration: 16000 / 20000 [ 80%]  (Warmup)
## Chain 1: Iteration: 18000 / 20000 [ 90%]  (Warmup)
## Chain 1: Iteration: 19001 / 20000 [ 95%]  (Sampling)
## Chain 1: Iteration: 20000 / 20000 [100%]  (Sampling)
## Chain 1:
## Chain 1:  Elapsed Time: 9.09895 seconds (Warm-up)
## Chain 1:                0.736377 seconds (Sampling)
## Chain 1:                9.83533 seconds (Total)
## Chain 1:
post_draws<-extract(weibull_fit)

Below we plot the posterior distribution of the hazard ratio. The red line indicates the true value under which we generated the data.

hist(post_draws$hazard_ratio, xlab='Hazard Ratio',
     main='Hazard Ratio Posterior Distribution')
abline(v=exp(-1*true_beta[2,1]*true_alpha), col='red')

mean(post_draws$hazard_ratio)
## [1] 0.3635223
quantile(post_draws$hazard_ratio, probs = c(.025, .975))
##      2.5%     97.5%
## 0.2989080 0.4399663

Below we plot the survival functions. Note these results are very similar to the augmented sampler coded in the previous post.

plot(survfit(Surv(survt, 1-delta) ~ A ), col=c('black','blue'),
     xlab='Time', ylab='Survival Probability', conf.int=TRUE)

# overlay one empirical survival curve per posterior draw
for(i in 1:1000){
  trt_ecdf <- ecdf(post_draws$post_pred_trt[i,])
  curve(1 - trt_ecdf(x), from = 0, to=4, add=TRUE, col='gray')

  pbo_ecdf <- ecdf(post_draws$post_pred_pbo[i,])
  curve(1 - pbo_ecdf(x), from = 0, to=4, add=TRUE, col='lightblue')
}

# redraw the Kaplan-Meier curves on top of the posterior draws
lines(survfit(Surv(survt, 1-delta) ~ A ), col=c('black','blue'),
      conf.int=TRUE)

legend('topright',
       legend = c('KM Curve and Intervals (TRT)',
                  'Posterior Survival Draws (TRT)',
                  'KM Curve and Intervals (PBO)',
                  'Posterior Survival Draws (PBO)'),
       col=c('black','gray','blue','lightblue'),
       lty=c(1,0,1,0), pch=c(NA,15,NA,15), bty='n')
