Can we predict flu deaths with Machine Learning and R?
Among the many R packages, there is the outbreaks package. It contains datasets on epidemics, one of which is from the 2013 outbreak of influenza A H7N9 in China, as analysed by Kucharski et al. (2014):
A. Kucharski, H. Mills, A. Pinsent, C. Fraser, M. Van Kerkhove, C. A. Donnelly, and S. Riley. 2014. Distinguishing between reservoir exposure and human-to-human transmission for emerging pathogens using case onset data. PLOS Currents Outbreaks. Mar 7, edition 1. doi: 10.1371/currents.outbreaks.e1473d9bfc99d080ca242139a06c455f.
A. Kucharski, H. Mills, A. Pinsent, C. Fraser, M. Van Kerkhove, C. A. Donnelly, and S. Riley. 2014. Data from: Distinguishing between reservoir exposure and human-to-human transmission for emerging pathogens using case onset data. Dryad Digital Repository. http://dx.doi.org/10.5061/dryad.2g43n.
I will be using their data as an example to test whether we can use Machine Learning algorithms for predicting disease outcome.
Disclaimer: I am not an expert in Machine Learning. Everything I know, I taught myself. So, if you identify any mistakes or have tips and tricks for improvement, please don’t hesitate to let me know! Thanks. 🙂
The data
The dataset contains case ID, date of onset, date of hospitalisation, date of outcome, gender, age, province and of course the outcome: Death or Recovery. I can already see that there are a couple of missing values in the data, which I will deal with later.
# install and load package
if (!require("outbreaks")) install.packages("outbreaks")
library(outbreaks)
fluH7N9.china.2013_backup <- fluH7N9.china.2013 # back up original dataset in case something goes awry along the way
# convert ? to NAs
fluH7N9.china.2013$age[which(fluH7N9.china.2013$age == "?")] <- NA
# prefix the existing case IDs with "case_"
fluH7N9.china.2013$case.ID <- paste("case", fluH7N9.china.2013$case.ID, sep = "_")
head(fluH7N9.china.2013)
## case.ID date.of.onset date.of.hospitalisation date.of.outcome outcome gender age province
## 1 case_1 2013-02-19 <NA> 2013-03-04 Death m 87 Shanghai
## 2 case_2 2013-02-27 2013-03-03 2013-03-10 Death m 27 Shanghai
## 3 case_3 2013-03-09 2013-03-19 2013-04-09 Death f 35 Anhui
## 4 case_4 2013-03-19 2013-03-27 <NA> <NA> f 45 Jiangsu
## 5 case_5 2013-03-19 2013-03-30 2013-05-15 Recover f 48 Jiangsu
## 6 case_6 2013-03-21 2013-03-28 2013-04-26 Death f 32 Jiangsu
Before I start preparing the data for Machine Learning, I want to get an idea of the distribution of the data points and their different variables by plotting.
Most provinces have only a handful of cases, so I am combining them into the category “Other” and keeping only Jiangsu, Shanghai and Zhejiang as separate provinces.
# gather for plotting with ggplot2
library(tidyr)
fluH7N9.china.2013_gather <- fluH7N9.china.2013 %>%
gather(Group, Date, date.of.onset:date.of.outcome)
# rearrange group order
fluH7N9.china.2013_gather$Group <- factor(fluH7N9.china.2013_gather$Group, levels = c("date.of.onset", "date.of.hospitalisation", "date.of.outcome"))
# rename groups
library(plyr)
fluH7N9.china.2013_gather$Group <- mapvalues(fluH7N9.china.2013_gather$Group, from = c("date.of.onset", "date.of.hospitalisation", "date.of.outcome"),
to = c("Date of onset", "Date of hospitalisation", "Date of outcome"))
# renaming provinces
fluH7N9.china.2013_gather$province <- mapvalues(fluH7N9.china.2013_gather$province,
from = c("Anhui", "Beijing", "Fujian", "Guangdong", "Hebei", "Henan", "Hunan", "Jiangxi", "Shandong", "Taiwan"),
to = rep("Other", 10))
# add a level for unknown gender
levels(fluH7N9.china.2013_gather$gender) <- c(levels(fluH7N9.china.2013_gather$gender), "unknown")
fluH7N9.china.2013_gather$gender[is.na(fluH7N9.china.2013_gather$gender)] <- "unknown"
# rearrange province order so that Other is the last
fluH7N9.china.2013_gather$province <- factor(fluH7N9.china.2013_gather$province, levels = c("Jiangsu", "Shanghai", "Zhejiang", "Other"))
# preparing my ggplot2 theme of choice
library(ggplot2)
my_theme <- function(base_size = 12, base_family = "sans"){
theme_minimal(base_size = base_size, base_family = base_family) +
theme(
axis.text = element_text(size = 12),
axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 0.5),
axis.title = element_text(size = 14),
panel.grid.major = element_line(color = "grey"),
panel.grid.minor = element_blank(),
panel.background = element_rect(fill = "aliceblue"),
strip.background = element_rect(fill = "lightgrey", color = "grey", size = 1),
strip.text = element_text(face = "bold", size = 12, color = "black"),
legend.position = "bottom",
legend.justification = "top",
legend.box = "horizontal",
legend.box.background = element_rect(colour = "grey50"),
legend.background = element_blank(),
panel.border = element_rect(color = "grey", fill = NA, size = 0.5)
)
}
# plotting raw data
ggplot(data = fluH7N9.china.2013_gather, aes(x = Date, y = as.numeric(age), fill = outcome)) +
stat_density2d(aes(alpha = ..level..), geom = "polygon") +
geom_jitter(aes(color = outcome, shape = gender), size = 1.5) +
geom_rug(aes(color = outcome)) +
labs(
fill = "Outcome",
color = "Outcome",
alpha = "Level",
shape = "Gender",
x = "Date in 2013",
y = "Age",
title = "2013 Influenza A H7N9 cases in China",
subtitle = "Dataset from 'outbreaks' package (Kucharski et al. 2014)",
caption = ""
) +
facet_grid(Group ~ province) +
my_theme() +
scale_shape_manual(values = c(15, 16, 17)) +
scale_color_brewer(palette="Set1", na.value = "grey50") +
scale_fill_brewer(palette="Set1")
This plot shows the dates of onset, hospitalisation and outcome (if known) of each data point. Outcome is marked by color and age shown on the y-axis. Gender is marked by point shape.
The density distribution of date by age for the cases seems to indicate that older people died more frequently in the Jiangsu and Zhejiang provinces than in Shanghai and in other provinces.
When we look at the distribution of points along the time axis, it suggests that there might be a positive correlation between the likelihood of death and an early onset or early outcome.
I also want to know how many cases there are for each gender and province and compare the genders’ age distribution.
fluH7N9.china.2013_gather_2 <- fluH7N9.china.2013_gather[, -4] %>%
gather(group_2, value, gender:province)
fluH7N9.china.2013_gather_2$value <- mapvalues(fluH7N9.china.2013_gather_2$value, from = c("m", "f", "unknown", "Other"),
to = c("Male", "Female", "Unknown gender", "Other province"))
fluH7N9.china.2013_gather_2$value <- factor(fluH7N9.china.2013_gather_2$value,
levels = c("Female", "Male", "Unknown gender", "Jiangsu", "Shanghai", "Zhejiang", "Other province"))
p1 <- ggplot(data = fluH7N9.china.2013_gather_2, aes(x = value, fill = outcome, color = outcome)) +
geom_bar(position = "dodge", alpha = 0.7, size = 1) +
my_theme() +
scale_fill_brewer(palette="Set1", na.value = "grey50") +
scale_color_brewer(palette="Set1", na.value = "grey50") +
labs(
color = "",
fill = "",
x = "",
y = "Count",
title = "2013 Influenza A H7N9 cases in China",
subtitle = "Gender and Province numbers of flu cases",
caption = ""
)
p2 <- ggplot(data = fluH7N9.china.2013_gather, aes(x = as.numeric(age), fill = outcome, color = outcome)) +
geom_density(alpha = 0.3, size = 1) +
geom_rug() +
scale_color_brewer(palette="Set1", na.value = "grey50") +
scale_fill_brewer(palette="Set1", na.value = "grey50") +
my_theme() +
labs(
color = "",
fill = "",
x = "Age",
y = "Density",
title = "",
subtitle = "Age distribution of flu cases",
caption = ""
)
library(gridExtra)
library(grid)
grid.arrange(p1, p2, ncol = 2)
In the dataset, there are more male than female cases and correspondingly, we see more deaths, recoveries and unknown outcomes in men than in women. This is potentially a problem later on for modeling because the inherent likelihoods for outcome are not directly comparable between the sexes.
Most unknown outcomes were recorded in Zhejiang. Similarly to gender, we don’t have an equal distribution of data points across provinces either.
When we look at the age distribution it is obvious that people who died tended to be slightly older than those who recovered. The density curve of unknown outcomes is more similar to that of death than of recovery, suggesting that among these people there might have been more deaths than recoveries.
And lastly, I want to plot how many days passed between onset, hospitalisation and outcome for each case.
ggplot(data = fluH7N9.china.2013_gather, aes(x = Date, y = as.numeric(age), color = outcome)) +
geom_point(aes(shape = gender), size = 1.5, alpha = 0.6) +
geom_path(aes(group = case.ID)) +
facet_wrap( ~ province, ncol = 2) +
my_theme() +
scale_shape_manual(values = c(15, 16, 17)) +
scale_color_brewer(palette="Set1", na.value = "grey50") +
scale_fill_brewer(palette="Set1") +
labs(
color = "Outcome",
shape = "Gender",
x = "Date in 2013",
y = "Age",
title = "2013 Influenza A H7N9 cases in China",
subtitle = "Dataset from 'outbreaks' package (Kucharski et al. 2014)",
caption = "\nTime from onset of flu to outcome."
)
This plot shows that there are many missing values in the dates, so it is hard to draw a general conclusion.
Features
In Machine Learning-speak, features are the variables used for model training. Using the right features dramatically influences the accuracy of the model.
Because we don’t have many features, I am keeping age as it is, but I am also generating new features:
- from the date information I am calculating the days between onset and outcome and between onset and hospitalisation
- I am converting gender into numeric values with 1 for female and 0 for male
- similarly, I am converting provinces to binary indicator variables (yes == 1, no == 0) for Shanghai, Zhejiang, Jiangsu and other provinces
- the same binary classification is given for whether a case was hospitalised, and whether they had an early onset or early outcome (earlier than the median date)
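As a minimal sketch of the feature engineering listed above, the following uses a tiny made-up data frame instead of the full dataset. The dates are copied from cases 1, 3 and 4 shown earlier; the data frame name `cases` is hypothetical and only serves this illustration.

```r
# Toy data frame (hypothetical name `cases`); dates copied from cases 1, 3 and 4 above
cases <- data.frame(
  date.of.onset   = as.Date(c("2013-02-19", "2013-03-09", "2013-03-19")),
  date.of.outcome = as.Date(c("2013-03-04", "2013-04-09", NA)),
  gender          = c("m", "f", "f")
)

# days between onset and outcome (NA when the outcome date is missing)
cases$days_onset_to_outcome <- as.numeric(
  difftime(cases$date.of.outcome, cases$date.of.onset, units = "days")
)

# binary indicator: 1 for female, 0 for male
cases$gender_f <- ifelse(cases$gender == "f", 1, 0)

# binary indicator: 1 if onset was earlier than the median onset date
cases$early_onset <- ifelse(cases$date.of.onset < median(cases$date.of.onset), 1, 0)

cases
```

Computing the difference directly on Date objects avoids having to strip the " days" suffix from a printed time difference.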
# preparing the data frame for modeling
#
library(dplyr)
dataset <- fluH7N9.china.2013 %>%
mutate(hospital = as.factor(ifelse(is.na(date.of.hospitalisation), 0, 1)),
gender_f = as.factor(ifelse(gender == "f", 1, 0)),
province_Jiangsu = as.factor(ifelse(province == "Jiangsu", 1, 0)),
province_Shanghai = as.factor(ifelse(province == "Shanghai", 1, 0)),
province_Zhejiang = as.factor(ifelse(province == "Zhejiang", 1, 0)),
province_other = as.factor(ifelse(province == "Zhejiang" | province == "Jiangsu" | province == "Shanghai", 0, 1)),
days_onset_to_outcome = as.numeric(as.character(gsub(" days", "",
as.Date(as.character(date.of.outcome), format = "%Y-%m-%d") -
as.Date(as.character(date.of.onset), format = "%Y-%m-%d")))),
days_onset_to_hospital = as.numeric(as.character(gsub(" days", "",
as.Date(as.character(date.of.hospitalisation), format = "%Y-%m-%d") -
as.Date(as.character(date.of.onset), format = "%Y-%m-%d")))),
age = as.numeric(as.character(age)),
early_onset = as.factor(ifelse(date.of.onset < summary(fluH7N9.china.2013$date.of.onset)[[3]], 1, 0)),
early_outcome = as.factor(ifelse(date.of.outcome < summary(fluH7N9.china.2013$date.of.outcome)[[3]], 1, 0))) %>%
subset(select = -c(2:4, 6, 8))
rownames(dataset) <- dataset$case.ID
dataset <- dataset[, -1]
head(dataset)
## outcome age hospital gender_f province_Jiangsu province_Shanghai province_Zhejiang province_other days_onset_to_outcome days_onset_to_hospital early_onset early_outcome
## case_1 Death 87 0 0 0 1 0 0 13 NA 1 1
## case_2 Death 27 1 0 0 1 0 0 11 4 1 1
## case_3 Death 35 1 1 0 0 0 1 31 10 1 1
## case_4 <NA> 45 1 1 1 0 0 0 NA 8 1 <NA>
## case_5 Recover 48 1 1 1 0 0 0 57 11 1 0
## case_6 Death 32 1 1 1 0 0 0 36 7 1 1
Imputing missing values
When looking at the dataset I created for modeling, it is obvious that we have quite a few missing values.
The missing values from the outcome column are what I want to predict but for the rest I would either have to remove the entire row from the data or impute the missing information. I decided to try the latter with the mice package.
# impute missing data
library(mice)
dataset_impute <- mice(dataset[, -1], printFlag = FALSE) # the outcome column (column 1) is left out of the imputation
dataset_impute
## Multiply imputed dataset
## Call:
## mice(data = dataset[, -1], printFlag = FALSE)
## Number of multiple imputations: 5
## Missing cells per column:
## age hospital gender_f province_Jiangsu province_Shanghai province_Zhejiang province_other days_onset_to_outcome days_onset_to_hospital early_onset early_outcome
## 2 0 2 0 0 0 0 67 74 10 65
## Imputation methods:
## age hospital gender_f province_Jiangsu province_Shanghai province_Zhejiang province_other days_onset_to_outcome days_onset_to_hospital early_onset early_outcome
## "pmm" "" "logreg" "" "" "" "" "pmm" "pmm" "logreg" "logreg"
## VisitSequence:
## age gender_f days_onset_to_outcome days_onset_to_hospital early_onset early_outcome
## 1 3 8 9 10 11
## PredictorMatrix:
## age hospital gender_f province_Jiangsu province_Shanghai province_Zhejiang province_other days_onset_to_outcome days_onset_to_hospital early_onset early_outcome
## age 0 1 1 1 1 1 1 1 1 1 1
## hospital 0 0 0 0 0 0 0 0 0 0 0
## gender_f 1 1 0 1 1 1 1 1 1 1 1
## province_Jiangsu 0 0 0 0 0 0 0 0 0 0 0
## province_Shanghai 0 0 0 0 0 0 0 0 0 0 0
## province_Zhejiang 0 0 0 0 0 0 0 0 0 0 0
## province_other 0 0 0 0 0 0 0 0 0 0 0
## days_onset_to_outcome 1 1 1 1 1 1 1 0 1 1 1
## days_onset_to_hospital 1 1 1 1 1 1 1 1 0 1 1
## early_onset 1 1 1 1 1 1 1 1 1 0 1
## early_outcome 1 1 1 1 1 1 1 1 1 1 0
## Random generator seed value: NA
# recombine imputed data frame with the outcome column
dataset_complete <- merge(dataset[, 1, drop = FALSE], mice::complete(dataset_impute, 1), by = "row.names", all = TRUE)
rownames(dataset_complete) <- dataset_complete$Row.names
dataset_complete <- dataset_complete[, -1]
Test, train and validation datasets
For building the model, I am separating the imputed data frame into training and test data. Test data are the 57 cases with unknown outcome.
summary(dataset$outcome)
## Death Recover NA's
## 32 47 57
The training data will be further divided for validation of the models: 70% of the training data will be kept for model building and the remaining 30% will be used for model testing.
I am using the caret package for modeling.
train_index <- which(is.na(dataset_complete$outcome)) # despite its name, this indexes the rows with unknown outcome (the test cases)
train_data <- dataset_complete[-train_index, ]
test_data <- dataset_complete[train_index, -1]
library(caret)
set.seed(27)
val_index <- createDataPartition(train_data$outcome, p = 0.7, list=FALSE)
val_train_data <- train_data[val_index, ]
val_test_data <- train_data[-val_index, ]
val_train_X <- val_train_data[,-1]
val_test_X <- val_test_data[,-1]
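To make the 70/30 split above more concrete, here is a rough base R sketch of a split that is stratified by outcome. It is only an approximation of what a stratified partition does, not caret's actual implementation, and it works on a made-up outcome vector that has the same class counts as the training data (32 deaths, 47 recoveries).

```r
set.seed(27)

# Made-up outcome vector with the class counts reported above
outcome <- factor(c(rep("Death", 32), rep("Recover", 47)))

# Sample 70% of the row indices within each class, rounding up
idx_by_class <- split(seq_along(outcome), outcome)
val_index_sketch <- sort(unlist(lapply(idx_by_class, function(idx) {
  sample(idx, size = ceiling(0.7 * length(idx)))
}), use.names = FALSE))

table(outcome[val_index_sketch])   # classes in the 70% part
table(outcome[-val_index_sketch])  # classes in the 30% part
```

Sampling within each class keeps the ratio of deaths to recoveries roughly the same in both parts, which matters with so few cases.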
Decision trees
To get an idea about how each feature contributes to the prediction of the outcome, I first built a decision tree based on the training data using rpart and rattle.
library(rpart)
library(rattle)
library(rpart.plot)
library(RColorBrewer)
set.seed(27)
fit <- rpart(outcome ~ .,
data = train_data,
method = "class",
control = rpart.control(xval = 10, minbucket = 2, cp = 0), parms = list(split = "information"))
fancyRpartPlot(fit)
This decision tree shows that cases with an early outcome were most likely to die when they were 68 or older, when they also had an early onset and when they were sick for fewer than 13 days. If a person was not among the first cases and was younger than 52, they had a good chance of recovering, but if they were 82 or older, they were more likely to die from the flu.
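The `split = "information"` setting used for the tree above selects splits with an entropy-based criterion. As a rough illustration of the idea, the sketch below computes the information gain of one candidate split. The helper names `entropy` and `info_gain` are hypothetical (they are not part of rpart), log base 2 is chosen for readability, and the left/right counts are made up.

```r
# Shannon entropy (in bits) of a vector of class counts
entropy <- function(counts) {
  p <- counts[counts > 0] / sum(counts)
  -sum(p * log2(p))
}

# Information gain: parent entropy minus the size-weighted entropy of the children
info_gain <- function(parent, left, right) {
  n <- sum(parent)
  entropy(parent) -
    (sum(left) / n) * entropy(left) -
    (sum(right) / n) * entropy(right)
}

parent <- c(Death = 32, Recover = 47)  # class counts in the training data
left   <- c(Death = 24, Recover = 10)  # made-up counts for one branch
right  <- c(Death = 8,  Recover = 37)  # made-up counts for the other branch

entropy(parent)
info_gain(parent, left, right)
```

A split that separates the classes well yields a large gain; a split that leaves both branches with the parent's class ratio yields a gain of zero.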
Feature Importance
Not all of the features I created will be equally important to the model. The decision tree already gave me an idea of which features might be most important but I also want to estimate feature importance using a Random Forest approach with repeated cross validation.
# prepare training scheme
control <- trainControl(method = "repeatedcv", number = 10, repeats = 10)
# train the model
set.seed(27)
model <- train(outcome ~ ., data = train_data, method = "rf", preProcess = NULL, trControl = control)
# estimate variable importance
importance <- varImp(model, scale=TRUE)
# prepare for plotting
importance_df_1 <- importance$importance
importance_df_1$group <- rownames(importance_df_1)
importance_df_1$group <- mapvalues(importance_df_1$group,
from = c("age", "hospital2", "gender_f2", "province_Jiangsu2", "province_Shanghai2", "province_Zhejiang2",
"province_other2", "days_onset_to_outcome", "days_onset_to_hospital", "early_onset2", "early_outcome2" ),
to = c("Age", "Hospital", "Female", "Jiangsu", "Shanghai", "Zhejiang",
"Other province", "Days onset to outcome", "Days onset to hospital", "Early onset", "Early outcome" ))
f = importance_df_1[order(importance_df_1$Overall, decreasing = FALSE), "group"]
importance_df_2 <- importance_df_1
importance_df_2$Overall <- 0
importance_df <- rbind(importance_df_1, importance_df_2)
# setting factor levels
importance_df <- within(importance_df, group <- factor(group, levels = f))
importance_df_1 <- within(importance_df_1, group <- factor(group, levels = f))
ggplot() +
geom_point(data = importance_df_1, aes(x = Overall, y = group, color = group), size = 2) +
geom_path(data = importance_df, aes(x = Overall, y = group, color = group, group = group), size = 1) +
scale_color_manual(values = rep(brewer.pal(3, "Set1")[1], 11)) + # brewer.pal() needs n >= 3; only the first color is used
my_theme() +
theme(legend.position = "none",
axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5)) +
labs(
x = "Importance",
y = "",
title = "2013 Influenza A H7N9 cases in China",
subtitle = "Scaled feature importance",
caption = "\nDetermined with Random Forest and
repeated cross validation (10 folds, 10 repeats)"
)
This tells me that age is the most important determining factor for predicting disease outcome, followed by days between onset and outcome, early outcome and days between onset and hospitalisation.
Feature Plot
Before I start actually building models, I want to check whether the distribution of feature values is comparable between training, validation and test datasets.
# prepare for plotting
dataset_complete_gather <- dataset_complete %>%
mutate(set = ifelse(rownames(dataset_complete) %in% rownames(test_data), "Test Data",
ifelse(rownames(dataset_complete) %in% rownames(val_train_data), "Validation Train Data",
ifelse(rownames(dataset_complete) %in% rownames(val_test_data), "Validation Test Data", "NA"))),
case_ID = rownames(.)) %>%
gather(group, value, age:early_outcome)
dataset_complete_gather$group <- mapvalues(dataset_complete_gather$group,
from = c("age", "hospital", "gender_f", "province_Jiangsu", "province_Shanghai", "province_Zhejiang",
"province_other", "days_onset_to_outcome", "days_onset_to_hospital", "early_onset", "early_outcome" ),
to = c("Age", "Hospital", "Female", "Jiangsu", "Shanghai", "Zhejiang",
"Other province", "Days onset to outcome", "Days onset to hospital", "Early onset", "Early outcome" ))
ggplot(data = dataset_complete_gather, aes(x = as.numeric(value), fill = outcome, color = outcome)) +
geom_density(alpha = 0.2) +
geom_rug() +
scale_color_brewer(palette="Set1", na.value = "grey50") +
scale_fill_brewer(palette="Set1", na.value = "grey50") +
my_theme() +
facet_wrap(set ~ group, ncol = 11, scales = "free") +
labs(
x = "Value",
y = "Density",
title = "2013 Influenza A H7N9 cases in China",
subtitle = "Features for classifying outcome",
caption = "\nDensity distribution of all features used for classification of flu outcome."
)
ggplot(subset(dataset_complete_gather, group == "Age" | group == "Days onset to hospital" | group == "Days onset to outcome"),
aes(x=outcome, y=as.numeric(value), fill=set)) + geom_boxplot() +
my_theme() +
scale_fill_brewer(palette="Set1") +
facet_wrap( ~ group, ncol = 3, scales = "free") +
labs(
fill = "",
x = "Outcome",
y = "Value",
title = "2013 Influenza A H7N9 cases in China",
subtitle = "Features for classifying outcome",
caption = "\nBoxplot of the features age, days from onset to hospitalisation and days from onset to outcome."
)
Luckily, the distributions look reasonably similar between the validation and test data, except for a few outliers.
Comparing Machine Learning algorithms
Before I try to predict the outcome of the unknown cases, I am testing the models’ accuracy with the validation datasets on a couple of algorithms. I have chosen only a few of the better-known algorithms, but caret implements many more.
Random Forest
Random Forest predictions are based on an ensemble of many classification trees, each grown on a bootstrap sample of the data; the final prediction is the majority vote across the trees.
set.seed(27)
model_rf <- caret::train(outcome ~ .,
data = val_train_data,
method = "rf",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_rf
## Random Forest
##
## 56 samples
## 11 predictors
## 2 classes: 'Death', 'Recover'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.7801905 0.5338408
## 6 0.7650952 0.4926366
## 11 0.7527619 0.4638073
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
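The resampling summary above (10 folds, repeated 10 times, sample sizes of 50 or 51) follows from how repeated k-fold cross-validation partitions the 56 cases. The sketch below reproduces only the bookkeeping with plain random folds; it ignores class stratification, so it should be read as a simplified illustration rather than a description of caret's internals.

```r
set.seed(27)

n <- 56        # number of cases in the validation training data
k <- 10        # folds per repeat
repeats <- 10  # number of repeats

# For every repeat, assign each case to one of k folds at random and
# record how many cases remain for training when each fold is held out
train_sizes <- unlist(lapply(seq_len(repeats), function(r) {
  fold_id <- sample(rep(seq_len(k), length.out = n))
  vapply(seq_len(k), function(f) sum(fold_id != f), integer(1))
}))

length(train_sizes)  # total number of resamples
table(train_sizes)   # training-set sizes across all resamples
```

With 56 cases and 10 folds, six folds hold 6 cases and four hold 5, which is why every model is fitted on either 50 or 51 cases.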
confusionMatrix(predict(model_rf, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Death Recover
## Death 4 4
## Recover 5 10
##
## Accuracy : 0.6087
## 95% CI : (0.3854, 0.8029)
## No Information Rate : 0.6087
## P-Value [Acc > NIR] : 0.5901
##
## Kappa : 0.1619
## Mcnemar's Test P-Value : 1.0000
##
## Sensitivity : 0.4444
## Specificity : 0.7143
## Pos Pred Value : 0.5000
## Neg Pred Value : 0.6667
## Prevalence : 0.3913
## Detection Rate : 0.1739
## Detection Prevalence : 0.3478
## Balanced Accuracy : 0.5794
##
## 'Positive' Class : Death
##
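The main statistics in this output can be recomputed by hand from the four cells of the confusion matrix, which helps when reading the outputs of the other models below. The following sketch hard-codes the Random Forest counts printed above; the variable names are arbitrary.

```r
# Confusion matrix of the Random Forest model (rows: prediction, columns: reference)
cm <- matrix(c(4, 5, 4, 10), nrow = 2,
             dimnames = list(Prediction = c("Death", "Recover"),
                             Reference  = c("Death", "Recover")))

n <- sum(cm)
accuracy    <- sum(diag(cm)) / n
sensitivity <- cm["Death", "Death"] / sum(cm[, "Death"])        # deaths correctly predicted
specificity <- cm["Recover", "Recover"] / sum(cm[, "Recover"])  # recoveries correctly predicted

# Cohen's kappa: observed agreement corrected for agreement expected by chance
expected    <- sum(rowSums(cm) * colSums(cm)) / n^2
cohen_kappa <- (accuracy - expected) / (1 - expected)

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, kappa = cohen_kappa), 4)
```

Because "Death" is the positive class, sensitivity is the share of actual deaths that the model identified, which is the weakest number for most of the models compared here.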
Elastic-Net Regularized Generalized Linear Models
Lasso and elastic net regularization extend generalized linear models with a penalty on the size of the coefficients, which is useful when we have correlated features in our model.
set.seed(27)
model_glmnet <- caret::train(outcome ~ .,
data = val_train_data,
method = "glmnet",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_glmnet
## glmnet
##
## 56 samples
## 11 predictors
## 2 classes: 'Death', 'Recover'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ...
## Resampling results across tuning parameters:
##
## alpha lambda Accuracy Kappa
## 0.10 0.0005154913 0.7211905 0.4218952
## 0.10 0.0051549131 0.7189524 0.4189835
## 0.10 0.0515491312 0.7469524 0.4704687
## 0.55 0.0005154913 0.7211905 0.4218952
## 0.55 0.0051549131 0.7236190 0.4280528
## 0.55 0.0515491312 0.7531905 0.4801031
## 1.00 0.0005154913 0.7228571 0.4245618
## 1.00 0.0051549131 0.7278571 0.4361809
## 1.00 0.0515491312 0.7678571 0.5094194
##
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 1 and lambda = 0.05154913.
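The two tuning parameters in this output control the elastic net penalty: lambda sets its overall strength and alpha mixes the ridge (squared) and lasso (absolute value) terms. The helper below is a hypothetical illustration of that penalty term, following the parameterisation described in the glmnet documentation; it is not a function from the package, and the coefficient vector is made up.

```r
# Elastic net penalty: lambda * ((1 - alpha) / 2 * sum(beta^2) + alpha * sum(|beta|))
elastic_net_penalty <- function(beta, lambda, alpha) {
  lambda * ((1 - alpha) / 2 * sum(beta^2) + alpha * sum(abs(beta)))
}

beta <- c(0.8, -0.3, 0, 1.2)  # made-up coefficient vector

elastic_net_penalty(beta, lambda = 0.0515, alpha = 1)     # pure lasso penalty
elastic_net_penalty(beta, lambda = 0.0515, alpha = 0)     # pure ridge penalty
elastic_net_penalty(beta, lambda = 0.0515, alpha = 0.55)  # a mixture of both
```

The selected alpha = 1 therefore corresponds to a pure lasso penalty, which tends to set some coefficients exactly to zero.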
confusionMatrix(predict(model_glmnet, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Death Recover
## Death 3 4
## Recover 6 10
##
## Accuracy : 0.5652
## 95% CI : (0.3449, 0.7681)
## No Information Rate : 0.6087
## P-Value [Acc > NIR] : 0.7418
##
## Kappa : 0.0496
## Mcnemar's Test P-Value : 0.7518
##
## Sensitivity : 0.3333
## Specificity : 0.7143
## Pos Pred Value : 0.4286
## Neg Pred Value : 0.6250
## Prevalence : 0.3913
## Detection Rate : 0.1304
## Detection Prevalence : 0.3043
## Balanced Accuracy : 0.5238
##
## 'Positive' Class : Death
##
k-Nearest Neighbors
k-nearest neighbors predicts the class of a case from the classes of its k closest training cases, using a chosen distance measure.
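As a bare-bones illustration of the idea, the sketch below classifies a new case by an unweighted majority vote among its k nearest training points using Euclidean distance. The function name `knn_predict` and the toy data (age and days from onset to outcome) are made up, and the kknn method used in this post additionally weights neighbors with a kernel, which is left out here.

```r
# Unweighted k-nearest-neighbor vote with Euclidean distance (hypothetical helper)
knn_predict <- function(train_x, train_y, new_x, k = 3) {
  # distance from the new case to every training case
  d <- sqrt(rowSums(sweep(train_x, 2, new_x)^2))
  nearest <- order(d)[seq_len(k)]
  votes <- table(train_y[nearest])
  names(votes)[which.max(votes)]
}

# Made-up training data: columns are age and days from onset to outcome
train_x <- matrix(c(30, 5,
                    35, 7,
                    40, 6,
                    70, 15,
                    75, 20,
                    80, 18), ncol = 2, byrow = TRUE)
train_y <- factor(c("Recover", "Recover", "Recover", "Death", "Death", "Death"))

knn_predict(train_x, train_y, new_x = c(72, 16), k = 3)
knn_predict(train_x, train_y, new_x = c(33, 6), k = 3)
```

Because the vote depends on raw distances, features on larger scales dominate unless the data are centered and scaled first.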
set.seed(27)
model_kknn <- caret::train(outcome ~ .,
data = val_train_data,
method = "kknn",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_kknn
## k-Nearest Neighbors
##
## 56 samples
## 11 predictors
## 2 classes: 'Death', 'Recover'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ...
## Resampling results across tuning parameters:
##
## kmax Accuracy Kappa
## 5 0.6531905 0.2670442
## 7 0.7120476 0.4031836
## 9 0.7106190 0.4004564
##
## Tuning parameter 'distance' was held constant at a value of 2
## Tuning parameter 'kernel' was held constant at a value of optimal
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were kmax = 7, distance = 2 and kernel = optimal.
confusionMatrix(predict(model_kknn, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Death Recover
## Death 3 3
## Recover 6 11
##
## Accuracy : 0.6087
## 95% CI : (0.3854, 0.8029)
## No Information Rate : 0.6087
## P-Value [Acc > NIR] : 0.5901
##
## Kappa : 0.1266
## Mcnemar's Test P-Value : 0.5050
##
## Sensitivity : 0.3333
## Specificity : 0.7857
## Pos Pred Value : 0.5000
## Neg Pred Value : 0.6471
## Prevalence : 0.3913
## Detection Rate : 0.1304
## Detection Prevalence : 0.2609
## Balanced Accuracy : 0.5595
##
## 'Positive' Class : Death
##
Penalized Discriminant Analysis
Penalized Discriminant Analysis is a penalized version of linear discriminant analysis and is also useful when we have highly correlated features.
set.seed(27)
model_pda <- caret::train(outcome ~ .,
data = val_train_data,
method = "pda",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pda
## Penalized Discriminant Analysis
##
## 56 samples
## 11 predictors
## 2 classes: 'Death', 'Recover'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ...
## Resampling results across tuning parameters:
##
## lambda Accuracy Kappa
## 0e+00 NaN NaN
## 1e-04 0.7255238 0.4373766
## 1e-01 0.7235238 0.4342554
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was lambda = 1e-04.
confusionMatrix(predict(model_pda, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Death Recover
## Death 4 4
## Recover 5 10
##
## Accuracy : 0.6087
## 95% CI : (0.3854, 0.8029)
## No Information Rate : 0.6087
## P-Value [Acc > NIR] : 0.5901
##
## Kappa : 0.1619
## Mcnemar's Test P-Value : 1.0000
##
## Sensitivity : 0.4444
## Specificity : 0.7143
## Pos Pred Value : 0.5000
## Neg Pred Value : 0.6667
## Prevalence : 0.3913
## Detection Rate : 0.1739
## Detection Prevalence : 0.3478
## Balanced Accuracy : 0.5794
##
## 'Positive' Class : Death
##
Stabilized Linear Discriminant Analysis
Stabilized Linear Discriminant Analysis is designed for high-dimensional data and correlated co-variables.
set.seed(27)
model_slda <- caret::train(outcome ~ .,
data = val_train_data,
method = "slda",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_slda
## Stabilized Linear Discriminant Analysis
##
## 56 samples
## 11 predictors
## 2 classes: 'Death', 'Recover'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ...
## Resampling results:
##
## Accuracy Kappa
## 0.6886667 0.3512234
##
##
confusionMatrix(predict(model_slda, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Death Recover
## Death 3 2
## Recover 6 12
##
## Accuracy : 0.6522
## 95% CI : (0.4273, 0.8362)
## No Information Rate : 0.6087
## P-Value [Acc > NIR] : 0.4216
##
## Kappa : 0.2069
## Mcnemar's Test P-Value : 0.2888
##
## Sensitivity : 0.3333
## Specificity : 0.8571
## Pos Pred Value : 0.6000
## Neg Pred Value : 0.6667
## Prevalence : 0.3913
## Detection Rate : 0.1304
## Detection Prevalence : 0.2174
## Balanced Accuracy : 0.5952
##
## 'Positive' Class : Death
##
Nearest Shrunken Centroids
Nearest Shrunken Centroids computes a standardized centroid for each class and shrinks each centroid toward the overall centroid for all classes.
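The shrinkage step is usually described as soft thresholding: each standardized difference between a class centroid and the overall centroid is moved toward zero by a threshold and set to exactly zero once it falls below it. The sketch below shows only that operation on made-up differences; `soft_threshold` is a hypothetical helper and not a function from the pamr package.

```r
# Soft thresholding: shrink values toward zero by delta, truncating at zero
soft_threshold <- function(d, delta) {
  sign(d) * pmax(abs(d) - delta, 0)
}

d <- c(-2.5, -0.1, 0.05, 1.8)  # made-up standardized centroid differences

soft_threshold(d, delta = 0.12)  # small threshold: only the weakest features vanish
soft_threshold(d, delta = 3.36)  # large threshold: every difference becomes zero
```

With a large enough threshold, all class centroids collapse onto the overall centroid and the features no longer separate the classes, so the threshold acts as a feature selection parameter.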
set.seed(27)
model_pam <- caret::train(outcome ~ .,
data = val_train_data,
method = "pam",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pam
## Nearest Shrunken Centroids
##
## 56 samples
## 11 predictors
## 2 classes: 'Death', 'Recover'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ...
## Resampling results across tuning parameters:
##
## threshold Accuracy Kappa
## 0.1200215 0.7283333 0.4418904
## 1.7403111 0.6455714 0.2319674
## 3.3606007 0.5904762 0.0000000
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was threshold = 0.1200215.
confusionMatrix(predict(model_pam, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Death Recover
## Death 4 3
## Recover 5 11
##
## Accuracy : 0.6522
## 95% CI : (0.4273, 0.8362)
## No Information Rate : 0.6087
## P-Value [Acc > NIR] : 0.4216
##
## Kappa : 0.2397
## Mcnemar's Test P-Value : 0.7237
##
## Sensitivity : 0.4444
## Specificity : 0.7857
## Pos Pred Value : 0.5714
## Neg Pred Value : 0.6875
## Prevalence : 0.3913
## Detection Rate : 0.1739
## Detection Prevalence : 0.3043
## Balanced Accuracy : 0.6151
##
## 'Positive' Class : Death
##
Single C5.0 Tree
C5.0 is another tree-based modeling algorithm.
set.seed(27)
model_C5.0Tree <- caret::train(outcome ~ .,
data = val_train_data,
method = "C5.0Tree",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_C5.0Tree
## Single C5.0 Tree
##
## 56 samples
## 11 predictors
## 2 classes: 'Death', 'Recover'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ...
## Resampling results:
##
## Accuracy Kappa
## 0.7103333 0.4047334
##
##
confusionMatrix(predict(model_C5.0Tree, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Death Recover
## Death 5 4
## Recover 4 10
##
## Accuracy : 0.6522
## 95% CI : (0.4273, 0.8362)
## No Information Rate : 0.6087
## P-Value [Acc > NIR] : 0.4216
##
## Kappa : 0.2698
## Mcnemar's Test P-Value : 1.0000
##
## Sensitivity : 0.5556
## Specificity : 0.7143
## Pos Pred Value : 0.5556
## Neg Pred Value : 0.7143
## Prevalence : 0.3913
## Detection Rate : 0.2174
## Detection Prevalence : 0.3913
## Balanced Accuracy : 0.6349
##
## 'Positive' Class : Death
##
Partial Least Squares
Partial least squares regression combines principal component analysis and multiple regression and is useful for modeling with correlated features.
set.seed(27)
model_pls <- caret::train(outcome ~ .,
data = val_train_data,
method = "pls",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pls
## Partial Least Squares
##
## 56 samples
## 11 predictors
## 2 classes: 'Death', 'Recover'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 50, 50, 51, 51, 51, 50, ...
## Resampling results across tuning parameters:
##
## ncomp Accuracy Kappa
## 1 0.6198571 0.2112982
## 2 0.6376190 0.2436222
## 3 0.6773810 0.3305780
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was ncomp = 3.
confusionMatrix(predict(model_pls, val_test_data[, -1]), val_test_data$outcome)
## Confusion Matrix and Statistics
##
## Reference
## Prediction Death Recover
## Death 3 2
## Recover 6 12
##
## Accuracy : 0.6522
## 95% CI : (0.4273, 0.8362)
## No Information Rate : 0.6087
## P-Value [Acc > NIR] : 0.4216
##
## Kappa : 0.2069
## Mcnemar's Test P-Value : 0.2888
##
## Sensitivity : 0.3333
## Specificity : 0.8571
## Pos Pred Value : 0.6000
## Neg Pred Value : 0.6667
## Prevalence : 0.3913
## Detection Rate : 0.1304
## Detection Prevalence : 0.2174
## Balanced Accuracy : 0.5952
##
## 'Positive' Class : Death
##
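The statistics that confusionMatrix() reports can be reproduced by hand, which helps when reading the tables above. The following is a minimal sketch in base R that uses the counts from the C5.0 tree output. The helper cm_stats() is a hypothetical name and not part of caret, and the confidence interval assumes the exact binomial interval from binom.test().

```r
# Confusion matrix for the C5.0 tree (rows = prediction, columns = reference),
# copied from the output above.
cm <- matrix(c(5, 4,
               4, 10),
             nrow = 2, byrow = TRUE,
             dimnames = list(Prediction = c("Death", "Recover"),
                             Reference  = c("Death", "Recover")))

# cm_stats() is a hypothetical helper, not a caret function.
cm_stats <- function(cm, positive = "Death") {
  n  <- sum(cm)
  po <- sum(diag(cm)) / n                       # observed agreement = accuracy
  pe <- sum(rowSums(cm) * colSums(cm)) / n^2    # agreement expected by chance
  negative <- setdiff(colnames(cm), positive)
  list(
    accuracy    = po,
    kappa       = (po - pe) / (1 - pe),
    sensitivity = cm[positive, positive] / sum(cm[, positive]),
    specificity = cm[negative, negative] / sum(cm[, negative]),
    nir         = max(colSums(cm)) / n,         # no-information rate
    ci          = as.numeric(binom.test(sum(diag(cm)), n)$conf.int)  # assumed exact binomial CI
  )
}

stats <- cm_stats(cm)
print(round(unlist(stats[c("accuracy", "kappa", "sensitivity", "specificity", "nir")]), 4))
print(round(stats$ci, 4))
```

The no-information rate is simply the share of the larger class (14 of 23 cases recovered). Because it falls inside the confidence interval of the accuracy, the test for Acc > NIR is not significant, which is consistent with the p-value of 0.4216 reported above.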
Comparing accuracy of models
# Create a list of models
models <- list(rf = model_rf, glmnet = model_glmnet, kknn = model_kknn, pda = model_pda, slda = model_slda,
pam = model_pam, C5.0Tree = model_C5.0Tree, pls = model_pls)
# Resample the models
resample_results <- resamples(models)
# Generate a summary
summary(resample_results, metric = c("Kappa", "Accuracy"))
##
## Call:
## summary.resamples(object = resample_results, metric = c("Kappa", "Accuracy"))
##
## Models: rf, glmnet, kknn, pda, slda, pam, C5.0Tree, pls
## Number of resamples: 100
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## rf -0.3636 0.3214 0.5714 0.5338 0.6800 1 0
## glmnet -0.3636 0.2768 0.5455 0.5094 0.6667 1 0
## kknn -0.3636 0.1667 0.5035 0.4032 0.6282 1 0
## pda -0.6667 0.2431 0.5455 0.4374 0.6667 1 0
## slda -0.5217 0.0000 0.4308 0.3512 0.6667 1 0
## pam -0.5000 0.2292 0.5455 0.4419 0.6667 1 0
## C5.0Tree -0.4286 0.1667 0.4167 0.4047 0.6154 1 0
## pls -0.6667 0.0000 0.3333 0.3306 0.6282 1 0
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## rf 0.4000 0.6667 0.8000 0.7802 0.8393 1 0
## glmnet 0.4000 0.6500 0.8000 0.7679 0.8333 1 0
## kknn 0.3333 0.6000 0.7571 0.7120 0.8333 1 0
## pda 0.2000 0.6000 0.8000 0.7255 0.8333 1 0
## slda 0.2000 0.6000 0.6905 0.6887 0.8333 1 0
## pam 0.2000 0.6000 0.8000 0.7283 0.8333 1 0
## C5.0Tree 0.2000 0.6000 0.7143 0.7103 0.8333 1 0
## pls 0.1429 0.5929 0.6667 0.6774 0.8333 1 0
bwplot(resample_results , metric = c("Kappa","Accuracy"))
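All models are similarly accurate: mean resampled accuracy ranges from 0.68 (pls) to 0.78 (rf), and the spread across resamples is much larger than the differences between models.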
Combined results of predicting validation test samples
To compare the predictions from all models, I summed up the prediction probabilities for Death and for Recovery across all models and calculated the log2 of the ratio of the summed Recovery probabilities to the summed Death probabilities. Cases with a log2 ratio above 1.5 were defined as Recover, cases with a log2 ratio below -1.5 as Death, and the remaining cases as uncertain.
results <- data.frame(randomForest = predict(model_rf, newdata = val_test_data[, -1], type="prob"),
glmnet = predict(model_glmnet, newdata = val_test_data[, -1], type="prob"),
kknn = predict(model_kknn, newdata = val_test_data[, -1], type="prob"),
pda = predict(model_pda, newdata = val_test_data[, -1], type="prob"),
slda = predict(model_slda, newdata = val_test_data[, -1], type="prob"),
pam = predict(model_pam, newdata = val_test_data[, -1], type="prob"),
C5.0Tree = predict(model_C5.0Tree, newdata = val_test_data[, -1], type="prob"),
pls = predict(model_pls, newdata = val_test_data[, -1], type="prob"))
results$sum_Death <- rowSums(results[, grep("Death", colnames(results))])
results$sum_Recover <- rowSums(results[, grep("Recover", colnames(results))])
results$log2_ratio <- log2(results$sum_Recover/results$sum_Death)
results$true_outcome <- val_test_data$outcome
results$pred_outcome <- ifelse(results$log2_ratio > 1.5, "Recover", ifelse(results$log2_ratio < -1.5, "Death", "uncertain"))
results$prediction <- ifelse(results$pred_outcome == results$true_outcome, "CORRECT",
ifelse(results$pred_outcome == "uncertain", "uncertain", "wrong"))
results[, -c(1:16)]
## sum_Death sum_Recover log2_ratio true_outcome pred_outcome prediction
## case_123 2.2236374 5.776363 1.37723972 Death uncertain uncertain
## case_127 0.7522674 7.247733 3.26821230 Recover Recover CORRECT
## case_128 2.4448101 5.555190 1.18411381 Recover uncertain uncertain
## case_14 2.9135620 5.086438 0.80387172 Recover uncertain uncertain
## case_19 2.2113377 5.788662 1.38831062 Death uncertain uncertain
## case_2 3.6899508 4.310049 0.22410277 Death uncertain uncertain
## case_20 4.5621189 3.437881 -0.40818442 Recover uncertain uncertain
## case_21 4.9960238 3.003976 -0.73390698 Recover uncertain uncertain
## case_34 6.1430233 1.856977 -1.72599312 Death Death CORRECT
## case_37 0.8717595 7.128241 3.03154401 Recover Recover CORRECT
## case_5 1.7574945 6.242505 1.82860496 Recover Recover CORRECT
## case_51 3.8917736 4.108226 0.07808791 Recover uncertain uncertain
## case_55 3.1548368 4.845163 0.61897986 Recover uncertain uncertain
## case_6 3.5163388 4.483661 0.35060317 Death uncertain uncertain
## case_61 5.0803241 2.919676 -0.79911232 Death uncertain uncertain
## case_65 1.1282313 6.871769 2.60661863 Recover Recover CORRECT
## case_74 5.3438427 2.656157 -1.00853699 Recover uncertain uncertain
## case_78 3.2183947 4.781605 0.57115378 Death uncertain uncertain
## case_79 1.9521617 6.047838 1.63134700 Recover Recover CORRECT
## case_8 3.6106045 4.389395 0.28178184 Death uncertain uncertain
## case_87 4.9712879 3.028712 -0.71491522 Death uncertain uncertain
## case_91 1.8648837 6.135116 1.71800508 Recover Recover CORRECT
## case_94 2.1198087 5.880191 1.47192901 Recover uncertain uncertain
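The combined predictions were either correct or uncertain: 7 of the 23 validation cases were classified correctly, 16 were left as uncertain and none was classified wrongly.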
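To make the cutoff easier to interpret, the classification rule can be written out on its own. This is a minimal sketch in base R. The helper classify_log2_ratio() is a hypothetical name that only restates the rule used in this post, and the three input rows are copied from the validation table.

```r
# classify_log2_ratio() is a hypothetical helper that restates the rule used
# in this post; the cutoff of 1.5 is the one chosen in the text.
classify_log2_ratio <- function(sum_recover, sum_death, cutoff = 1.5) {
  log2_ratio <- log2(sum_recover / sum_death)
  ifelse(log2_ratio > cutoff, "Recover",
         ifelse(log2_ratio < -cutoff, "Death", "uncertain"))
}

# Three rows copied from the validation table
# (case_127, case_34 and case_94).
sum_death   <- c(0.7522674, 6.1430233, 2.1198087)
sum_recover <- c(7.247733, 1.856977, 5.880191)
calls <- classify_log2_ratio(sum_recover, sum_death)
print(calls)

# With 8 models the two sums add up to 8, so the cutoff can be
# translated into a mean predicted probability for one outcome.
n_models <- 8
min_mean_prob <- 2^1.5 / (1 + 2^1.5)
print(round(c(min_sum         = n_models * min_mean_prob,
              min_mean_prob   = min_mean_prob), 3))
```

In other words, a case is only called Recover or Death when the eight models assign, on average, roughly 74% or more probability to that outcome. Everything in between stays uncertain, which explains why so many validation cases end up in that group.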
Predicting unknown outcomes
The same eight algorithms will now be retrained on the complete training data and used to predict the outcome of cases with unknown fate.
set.seed(27)
model_rf <- caret::train(outcome ~ .,
data = train_data,
method = "rf",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_glmnet <- caret::train(outcome ~ .,
data = train_data,
method = "glmnet",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_kknn <- caret::train(outcome ~ .,
data = train_data,
method = "kknn",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pda <- caret::train(outcome ~ .,
data = train_data,
method = "pda",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_slda <- caret::train(outcome ~ .,
data = train_data,
method = "slda",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pam <- caret::train(outcome ~ .,
data = train_data,
method = "pam",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_C5.0Tree <- caret::train(outcome ~ .,
data = train_data,
method = "C5.0Tree",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
model_pls <- caret::train(outcome ~ .,
data = train_data,
method = "pls",
preProcess = NULL,
trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))
models <- list(rf = model_rf, glmnet = model_glmnet, kknn = model_kknn, pda = model_pda, slda = model_slda,
pam = model_pam, C5.0Tree = model_C5.0Tree, pls = model_pls)
# Resample the models
resample_results <- resamples(models)
bwplot(resample_results , metric = c("Kappa","Accuracy"))
Here again, the accuracy is similar in all models.
The final results are calculated as described above.
results <- data.frame(randomForest = predict(model_rf, newdata = test_data, type="prob"),
glmnet = predict(model_glmnet, newdata = test_data, type="prob"),
kknn = predict(model_kknn, newdata = test_data, type="prob"),
pda = predict(model_pda, newdata = test_data, type="prob"),
slda = predict(model_slda, newdata = test_data, type="prob"),
pam = predict(model_pam, newdata = test_data, type="prob"),
C5.0Tree = predict(model_C5.0Tree, newdata = test_data, type="prob"),
pls = predict(model_pls, newdata = test_data, type="prob"))
results$sum_Death <- rowSums(results[, grep("Death", colnames(results))])
results$sum_Recover <- rowSums(results[, grep("Recover", colnames(results))])
results$log2_ratio <- log2(results$sum_Recover/results$sum_Death)
results$predicted_outcome <- ifelse(results$log2_ratio > 1.5, "Recover", ifelse(results$log2_ratio < -1.5, "Death", "uncertain"))
results[, -c(1:16)]
## sum_Death sum_Recover log2_ratio predicted_outcome
## case_100 3.6707808 4.329219 0.23801986 uncertain
## case_101 6.3194220 1.680578 -1.91083509 Death
## case_102 6.4960471 1.503953 -2.11080274 Death
## case_103 2.3877223 5.612278 1.23295134 uncertain
## case_104 4.1254604 3.874540 -0.09053024 uncertain
## case_105 2.1082161 5.891784 1.48268174 uncertain
## case_108 6.5941436 1.405856 -2.22973606 Death
## case_109 2.2780779 5.721922 1.32868278 uncertain
## case_110 1.7853361 6.214664 1.79948072 Recover
## case_112 4.5102439 3.489756 -0.37007925 uncertain
## case_113 4.7047554 3.295245 -0.51373420 uncertain
## case_114 1.6397105 6.360290 1.95565132 Recover
## case_115 1.2351519 6.764848 2.45336902 Recover
## case_118 1.4675571 6.532443 2.15420601 Recover
## case_120 1.9322149 6.067785 1.65091447 Recover
## case_122 1.1165786 6.883421 2.62404109 Recover
## case_126 5.4933927 2.506607 -1.13196145 uncertain
## case_130 4.3035732 3.696427 -0.21940367 uncertain
## case_132 5.0166184 2.983382 -0.74976671 uncertain
## case_136 2.6824566 5.317543 0.98720505 uncertain
## case_15 5.8545370 2.145463 -1.44826607 uncertain
## case_16 6.5063794 1.493621 -2.12304122 Death
## case_22 4.2929861 3.707014 -0.21172398 uncertain
## case_28 6.6129447 1.387055 -2.25326754 Death
## case_31 6.2625800 1.737420 -1.84981057 Death
## case_32 6.0986316 1.901368 -1.68144746 Death
## case_38 1.7581540 6.241846 1.82791134 Recover
## case_39 2.4765188 5.523481 1.15726428 uncertain
## case_4 5.1309910 2.869009 -0.83868503 uncertain
## case_40 6.6069297 1.393070 -2.24571191 Death
## case_41 5.5499144 2.450086 -1.17963336 uncertain
## case_42 2.0709160 5.929084 1.51754019 Recover
## case_47 6.1988812 1.801119 -1.78311450 Death
## case_48 5.6500772 2.349923 -1.26565724 uncertain
## case_52 5.9096543 2.090346 -1.49933214 uncertain
## case_54 2.6763066 5.323693 0.99218409 uncertain
## case_56 2.6826414 5.317359 0.98705557 uncertain
## case_62 6.0722552 1.927745 -1.65531833 Death
## case_63 3.9541381 4.045862 0.03308379 uncertain
## case_66 5.8320010 2.167999 -1.42762690 uncertain
## case_67 2.1732059 5.826794 1.42287745 uncertain
## case_68 4.3464174 3.653583 -0.25051491 uncertain
## case_69 1.8902845 6.109715 1.69250181 Recover
## case_70 4.5986242 3.401376 -0.43508393 uncertain
## case_71 2.7966938 5.203306 0.89570626 uncertain
## case_80 2.7270986 5.272901 0.95123017 uncertain
## case_84 1.9485253 6.051475 1.63490411 Recover
## case_85 3.0076953 4.992305 0.73104755 uncertain
## case_86 2.0417802 5.958220 1.54505376 Recover
## case_88 0.7285419 7.271458 3.31916083 Recover
## case_9 2.2188620 5.781138 1.38153353 uncertain
## case_90 1.3973262 6.602674 2.24038155 Recover
## case_92 5.0994993 2.900501 -0.81405366 uncertain
## case_93 4.3979048 3.602095 -0.28798006 uncertain
## case_95 3.5561792 4.443821 0.32147260 uncertain
## case_96 1.3359082 6.664092 2.31858744 Recover
## case_99 5.0776686 2.922331 -0.79704644 uncertain
Of the 57 cases with unknown outcome, 14 were defined as Recover, 10 as Death and 33 as uncertain.
results_combined <- merge(results[, -c(1:16)], fluH7N9.china.2013[which(fluH7N9.china.2013$case.ID %in% rownames(results)), ],
by.x = "row.names", by.y = "case.ID")
results_combined <- results_combined[, -c(2, 3, 8, 9)]
results_combined_gather <- results_combined %>%
gather(group_dates, date, date.of.onset:date.of.hospitalisation)
results_combined_gather$group_dates <- factor(results_combined_gather$group_dates, levels = c("date.of.onset", "date.of.hospitalisation"))
results_combined_gather$group_dates <- mapvalues(results_combined_gather$group_dates, from = c("date.of.onset", "date.of.hospitalisation"),
to = c("Date of onset", "Date of hospitalisation"))
results_combined_gather$gender <- mapvalues(results_combined_gather$gender, from = c("f", "m"),
to = c("Female", "Male"))
levels(results_combined_gather$gender) <- c(levels(results_combined_gather$gender), "unknown")
results_combined_gather$gender[is.na(results_combined_gather$gender)] <- "unknown"
ggplot(data = results_combined_gather, aes(x = date, y = log2_ratio, color = predicted_outcome)) +
geom_jitter(aes(size = as.numeric(age)), alpha = 0.3) +
geom_rug() +
facet_grid(gender ~ group_dates) +
labs(
color = "Predicted outcome",
size = "Age",
x = "Date in 2013",
y = "log2 ratio of prediction Recover vs Death",
title = "2013 Influenza A H7N9 cases in China",
subtitle = "Predicted outcome",
caption = ""
) +
my_theme() +
scale_shape_manual(values = c(15, 16, 17)) +
scale_color_brewer(palette="Set1") +
scale_fill_brewer(palette="Set1")
The comparison of date of onset, date of hospitalisation, gender and age with predicted outcome shows that predicted deaths were associated with older age than predicted recoveries. Date of onset does not show a bias in either direction.
Conclusion
The dataset posed a couple of difficulties, like an unequal distribution of data points across variables and missing values. This makes the modeling inherently prone to flaws. However, real-life data is seldom perfect, so I went ahead and tested the modeling success anyway.
By allowing an uncertain class, none of the validation cases was classified wrongly, but most of them (16 of 23) remained uncertain. These few cases simply don’t give enough information to reliably predict the outcome. More cases, more information (i.e. more features) and fewer missing data would improve the modeling outcome.
Unfortunately, it can’t be verified whether the unknown cases were classified correctly. Also, these data are specific to this outbreak of flu. To draw more general conclusions about flu outcome, other cases and additional information, for example on medical parameters, would be necessary. However, this dataset served as a nice example for the possibilities (and pitfalls) of machine learning applications and showcased a basic workflow for building prediction models with R.
sessionInfo()
## R version 3.3.2 (2016-10-31)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 7 x64 (build 7601) Service Pack 1
##
## locale:
## [1] LC_COLLATE=English_United States.1252 LC_CTYPE=English_United States.1252 LC_MONETARY=English_United States.1252 LC_NUMERIC=C LC_TIME=English_United States.1252
##
## attached base packages:
## [1] grid stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] pls_2.5-0 C50_0.1.0-24 pamr_1.55 survival_2.40-1 cluster_2.0.5 ipred_0.9-5 mda_0.4-9 class_7.3-14 kknn_1.3.1 glmnet_2.0-5 foreach_1.4.3 Matrix_1.2-7.1 randomForest_4.6-12 RColorBrewer_1.1-2 rpart.plot_2.1.0 rattle_4.1.0 rpart_4.1-10 caret_6.0-73 lattice_0.20-34 mice_2.25 Rcpp_0.12.7 dplyr_0.5.0 gridExtra_2.2.1 ggplot2_2.2.0 plyr_1.8.4 tidyr_0.6.0 outbreaks_1.0.0
##
## loaded via a namespace (and not attached):
## [1] RGtk2_2.20.31 assertthat_0.1 digest_0.6.10 R6_2.2.0 MatrixModels_0.4-1 stats4_3.3.2 evaluate_0.10 e1071_1.6-7 lazyeval_0.2.0 minqa_1.2.4 SparseM_1.74 car_2.1-3 nloptr_1.0.4 partykit_1.1-1 rmarkdown_1.1 labeling_0.3 splines_3.3.2 lme4_1.1-12 stringr_1.1.0 igraph_1.0.1 munsell_0.4.3 compiler_3.3.2 mgcv_1.8-16 htmltools_0.3.5 nnet_7.3-12 tibble_1.2 prodlim_1.5.7 codetools_0.2-15 MASS_7.3-45 ModelMetrics_1.1.0 nlme_3.1-128 gtable_0.2.0 DBI_0.5-1 magrittr_1.5 scales_0.4.1 stringi_1.1.2 reshape2_1.4.2 Formula_1.2-1 lava_1.4.5 iterators_1.0.8 tools_3.3.2 parallel_3.3.2 pbkrtest_0.4-6 yaml_2.1.14 colorspace_1.3-0 knitr_1.15 quantreg_5.29