TL;DR
Partial least squares discriminant analysis (PLS-DA) can badly overfit, even on completely random data. The quality of a PLS-DA model can be assessed using cross-validation, but many metabolomics publications do not report it. Random forest, in contrast, automatically provides an indication of model quality, because each decision tree in the forest is tested on the out-of-bag (OOB) samples that were left out when it was trained.
Why?
I’ve recently been doing some machine learning work using random forests (RF; Breiman, 2001) on metabolomics data. This has been relatively successful, with decent sensitivity and specificity, and hopefully I’ll be able to post more soon. However, PLS (Wold, 1975) is a standard technique in metabolomics, due to the prevalence of analytical chemists in the field and their long familiarity with the method. Importantly, my collaborators frequently use PLS-DA to generate plots showing that the various classes of samples are separable.
However, it has long been known that PLS (and all of its variants: PLS-DA, OPLS, OPLS-DA, etc.) can easily generate models that overfit the data, and that overfitting needs to be assessed if the model is going to be used in subsequent analyses.
Random Data
To illustrate the behavior of both RF and PLS-DA, we will generate some random data where each sample is arbitrarily assigned to one of two classes.
Feature Intensities
We will generate a data set with 1000 features, where each feature’s mean value is drawn from a uniform distribution with a range of 0 to 10000.
library(cowplot)
## Loading required package: ggplot2
##
## Attaching package: 'cowplot'
## The following object is masked from 'package:ggplot2':
##
## ggsave
library(fakeDataWithError)
set.seed(1234)
n_point <- 1000
max_value <- 10000
init_values <- runif(n_point, 0, max_value)
init_data <- data.frame(data = init_values)
ggplot(init_data, aes(x = data)) + geom_histogram() + ggtitle("Initial Data")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
For each of these features, the values across samples are drawn from a normal distribution with the initial feature value as its mean and a standard deviation of 200. The number of samples is 100.
n_sample <- 100
error_values <- add_uniform_noise(n_sample, init_values, 200)
Just for information, the add_uniform_noise function is shown below. Despite its name, it draws the noise with rnorm, so the noise is normally distributed:
add_uniform_noise
## function (n_rep, value, sd, use_zero = FALSE)
## {
## n_value <- length(value)
## n_sd <- n_rep * n_value
## out_sd <- rnorm(n_sd, 0, sd)
## out_sd <- matrix(out_sd, nrow = n_value, ncol = n_rep)
## if (!use_zero) {
## tmp_value <- matrix(value, nrow = n_value, ncol = n_rep,
## byrow = FALSE)
## out_value <- tmp_value + out_sd
## }
## else {
## out_value <- out_sd
## }
## return(out_value)
## }
## <bytecode: 0x56398fe59900>
## <environment: namespace:fakeDataWithError>
I created it as part of a package that is able to add different kinds of noise to data.
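For readers without the package installed, the same operation can be sketched in base R. The helper name add_normal_noise and the small demo sizes are illustrative assumptions, not part of fakeDataWithError:

```r
# Hypothetical stand-in for the use_zero = FALSE branch of add_uniform_noise
add_normal_noise <- function(n_rep, value, sd) {
  n_value <- length(value)
  # one row per feature, one column per sample
  noise <- matrix(rnorm(n_rep * n_value, mean = 0, sd = sd),
                  nrow = n_value, ncol = n_rep)
  # `value` is recycled down the rows, so row i is shifted by value[i]
  value + noise
}

set.seed(1234)
demo_means <- runif(5, 0, 10000)
demo_values <- add_normal_noise(100, demo_means, 200)
dim(demo_values)                    # features x samples
round(apply(demo_values, 1, sd))    # each should be near 200
```

Each row scatters around its own mean with roughly the requested standard deviation, which is all the simulation above relies on.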
The distribution of values for a single feature looks like this:
error_data <- data.frame(feature_1 = error_values[1,])
ggplot(error_data, aes(x = feature_1)) + geom_histogram() + ggtitle("Error Data")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
And we will assign the first 50 samples to class_1 and the second 50 samples to class_2.
sample_class <- rep(c("class_1", "class_2"), each = 50)
sample_class
## [1] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
## [8] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
## [15] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
## [22] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
## [29] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
## [36] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
## [43] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
## [50] "class_1" "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
## [57] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
## [64] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
## [71] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
## [78] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
## [85] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
## [92] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
## [99] "class_2" "class_2"
PCA
Just to show that the data is pretty random, let’s use principal components analysis (PCA) to do a decomposition, and plot the first two components:
tmp_pca <- prcomp(t(error_values), center = TRUE, scale. = TRUE)
pca_data <- as.data.frame(tmp_pca$x[, 1:2])
pca_data$class <- as.factor(sample_class)
ggplot(pca_data, aes(x = PC1, y = PC2, color = class)) + geom_point(size = 4)
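One way to put a number on "pretty random" is the fraction of variance carried by each component. The sketch below rebuilds stand-in data of the same shape with base R only (the names sim_values and sim_pca are illustrative). With independent noise, one would expect the variance to be spread thinly over many components rather than concentrated in the first two:

```r
# Stand-in data with the same shape as error_values: 1000 features x 100 samples
set.seed(1234)
feature_means <- runif(1000, 0, 10000)
sim_values <- feature_means + matrix(rnorm(1000 * 100, 0, 200), nrow = 1000)

sim_pca <- prcomp(t(sim_values), center = TRUE, scale. = TRUE)
# proportion of total variance carried by each principal component
var_explained <- sim_pca$sdev^2 / sum(sim_pca$sdev^2)
round(var_explained[1:5], 3)
sum(var_explained[1:2])   # share of variance shown in a PC1 vs PC2 plot
```

If real class structure were present, the leading components would typically carry a visibly larger share than the rest.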
Random Forest
Let’s use RF first, and see how things look.
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
rf_model <- randomForest(t(error_values), y = as.factor(sample_class))
The confusion matrix comparing actual vs predicted classes based on the out-of-bag (OOB) samples:
knitr::kable(rf_model$confusion)
| | class_1 | class_2 | class.error |
|---|---|---|---|
| class_1 | 21 | 29 | 0.58 |
| class_2 | 28 | 22 | 0.56 |
And an overall error of 0.5760364, which is about what random guessing would give for two balanced classes.
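The per-class errors follow directly from the counts in the confusion matrix. A small base-R sketch that recomputes them from the numbers in the table (the object names are illustrative):

```r
# Counts copied from the confusion matrix above (rows = actual, columns = OOB prediction)
conf_counts <- matrix(c(21, 29,
                        28, 22),
                      nrow = 2, byrow = TRUE,
                      dimnames = list(actual = c("class_1", "class_2"),
                                      predicted = c("class_1", "class_2")))
# per-class error: fraction of each row that falls off the diagonal
class_error <- 1 - diag(conf_counts) / rowSums(conf_counts)
class_error
# pooled error over all 100 samples
pooled_error <- 1 - sum(diag(conf_counts)) / sum(conf_counts)
pooled_error
```

Pooling these final counts gives 0.57. The overall figure quoted in the text may have been summarized differently, but either way the error sits close to the 0.5 expected from guessing.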
PLS-DA
So PLS-DA is really just PLS with a binary y variable.
library(caret)
## Loading required package: lattice
pls_model <- plsda(t(error_values), as.factor(sample_class), ncomp = 2)
pls_scores <- data.frame(comp1 = pls_model$scores[,1], comp2 = pls_model$scores[,2], class = sample_class)
And plot the PLS scores:
ggplot(pls_scores, aes(x = comp1, y = comp2, color = class)) + geom_point(size = 4) + ggtitle("PLS-DA of Random Data")
And voila! Perfectly separated data! If I didn’t tell you that it was random, would you suspect it?
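Why does random data separate so cleanly? With far more features (1000) than samples (100), there is almost always some weighted combination of features that lines up with the labels of the training samples. The following is a minimal base-R sketch, not caret's implementation. For a single response, the first PLS weight vector is proportional to X'y on centered data, so the first score is built directly from the class labels. The stand-in data and variable names are illustrative assumptions:

```r
set.seed(1234)
n_feature <- 1000
n_sample <- 100
# samples x features, same construction as the random data above
x <- t(runif(n_feature, 0, 10000) +
         matrix(rnorm(n_feature * n_sample, 0, 200), nrow = n_feature))
y <- rep(c(0, 1), each = n_sample / 2)         # class_1 = 0, class_2 = 1

x_c <- scale(x, center = TRUE, scale = FALSE)  # column-center the features
y_c <- y - mean(y)

# first PLS weight vector: the direction of maximal covariance with y
w <- crossprod(x_c, y_c)
w <- w / sqrt(sum(w^2))
score_1 <- drop(x_c %*% w)                     # first-component score per training sample

# class means of the training scores
tapply(score_1, y, mean)
# fraction of (class_2, class_1) pairs ordered correctly on the training data
train_auc <- mean(outer(score_1[y == 1], score_1[y == 0], ">"))
train_auc
```

The separation here is measured on the same samples used to choose the weights, which is exactly why it says nothing about how the model would do on new samples.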
Cross-validated PLS-DA
Of course, one way to truly assess the worth of the model would be to use cross-validation, where a fraction of the data is held back and the model is trained on the rest. Predictions are then made on the held-back fraction, and because we know the truth, we can calculate the area under the receiver operating characteristic curve (AUROC, or simply AUC), the curve created by plotting the true positive rate against the false positive rate.
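The AUC has a handy interpretation: it is the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, so 0.5 corresponds to guessing and 1 to perfect ranking. A small base-R sketch of that definition using ranks (the function name auc_by_rank is illustrative; this is not how ROCR computes it internally):

```r
# AUC as the share of positive/negative pairs that are ordered correctly
# (the Mann-Whitney U statistic divided by n_pos * n_neg)
auc_by_rank <- function(score, is_positive) {
  n_pos <- sum(is_positive)
  n_neg <- sum(!is_positive)
  r <- rank(score)   # ties get average ranks, which counts a tie as half a win
  (sum(r[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# 3 of the 4 positive/negative pairs are ordered correctly, giving 0.75
auc_by_rank(c(0.1, 0.4, 0.35, 0.8), c(FALSE, FALSE, TRUE, TRUE))
```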
To do this we will need two functions:
- Generates all of the CV folds
- Generates PLS-DA model, does prediction on hold out, calculates AUC
library(cvTools)
## Loading required package: robustbase
library(ROCR)
## Loading required package: gplots
##
## Attaching package: 'gplots'
## The following object is masked from 'package:stats':
##
## lowess
gen_cv <- function(xdata, ydata, nrep, kfold){
  n_sample <- length(ydata)
  all_index <- seq(1, n_sample)
  # nrep random partitions of the samples into kfold folds
  cv_data <- cvFolds(n_sample, K = kfold, R = nrep, type = "random")
  # result: one row per fold, one column per replicate
  rep_values <- vapply(seq(1, nrep), function(in_rep){
    use_rep <- cv_data$subsets[, in_rep]
    cv_values <- vapply(seq(1, kfold), function(in_fold){
      # hold out one fold, train on everything else
      test_index <- use_rep[cv_data$which == in_fold]
      train_index <- all_index[-test_index]
      plsda_cv(xdata[train_index, ], ydata[train_index], xdata[test_index, ],
               ydata[test_index])
    }, numeric(1))
    cv_values
  }, numeric(kfold))
  rep_values
}
plsda_cv <- function(xtrain, ytrain, xtest, ytest){
  pls_model <- plsda(xtrain, ytrain, ncomp = 2)
  pls_pred <- predict(pls_model, xtest, type = "prob")
  # predicted probability of the second class for each held-out sample
  use_pred <- pls_pred[, 2, 1]
  pred_perf <- ROCR::prediction(use_pred, ytest)
  pred_auc <- ROCR::performance(pred_perf, "auc")@y.values[[1]]
  return(pred_auc)
}
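The fold bookkeeping that cvFolds handles inside gen_cv can also be sketched in base R. This is a simplified illustration of a single replicate, not a replacement for cvTools, and the variable names are assumptions:

```r
set.seed(1234)
n_sample <- 100
kfold <- 5
# shuffle a balanced vector of fold labels so every sample gets exactly one fold
fold_id <- sample(rep(seq_len(kfold), length.out = n_sample))

for (in_fold in seq_len(kfold)) {
  test_index <- which(fold_id == in_fold)
  train_index <- setdiff(seq_len(n_sample), test_index)
  # a model would be fit on train_index and scored on test_index here
  cat("fold", in_fold, ":", length(train_index), "train,",
      length(test_index), "test\n")
}
```

Every sample is used for testing exactly once per replicate, and never by a model that saw it during training.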
And now let’s do a bunch of replicates (100).
cv_vals <- gen_cv(t(error_values), factor(sample_class), nrep = 100, kfold = 5)
mean(cv_vals)
## [1] 0.4198336
sd(cv_vals)
## [1] 0.1134857
cv_frame <- data.frame(auc = as.vector(cv_vals))
ggplot(cv_frame, aes(x = auc)) + geom_histogram(binwidth = 0.01)
So we get an average AUC of 0.4198336, which is no better than the 0.5 expected from random guessing. This implies that even though there was good separation on the scores, the model has no real predictive ability, and we should be cautious of any predictions made with it.
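For a sense of scale, it helps to see how much the AUC varies by chance alone on a test fold of about 20 samples. The sketch below assumes 10 samples per class in each fold (random folds need not be balanced) and scores that carry no information at all:

```r
set.seed(1234)
null_auc <- replicate(2000, {
  score <- rnorm(20)                        # uninformative predictions
  is_pos <- rep(c(TRUE, FALSE), each = 10)
  # pairwise definition of AUC: share of positive/negative pairs ordered correctly
  mean(outer(score[is_pos], score[!is_pos], ">"))
})
mean(null_auc)   # close to 0.5
sd(null_auc)     # spread expected from chance alone on 20 samples
```

Single-fold AUC values on so few samples are noisy, which is one reason to average over many replicates as done above.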
Of course, the PCA at the beginning of the analysis shows that there is no real separation in the data in the first place.
devtools::session_info()
## ─ Session info ──────────────────────────────────────────────────────────
## setting value
## version R version 3.5.1 (2018-07-02)
## os Ubuntu 18.04.3 LTS
## system x86_64, linux-gnu
## ui X11
## language (EN)
## collate en_US.UTF-8
## ctype en_US.UTF-8
## tz America/New_York
## date 2019-10-16
##
## ─ Packages ──────────────────────────────────────────────────────────────
## package * version date lib
## assertthat 0.2.1 2019-03-21 [1]
## backports 1.1.4 2019-04-10 [1]
## bitops 1.0-6 2013-08-17 [1]
## blogdown 0.10 2019-01-09 [1]
## bookdown 0.9 2018-12-21 [1]
## callr 3.1.1 2018-12-21 [1]
## caret * 6.0-81 2018-11-20 [1]
## caTools 1.17.1.1 2018-07-20 [1]
## class 7.3-14 2015-08-30 [1]
## cli 1.1.0 2019-03-19 [1]
## codetools 0.2-15 2016-10-05 [1]
## colorspace 1.4-1 2019-03-18 [1]
## cowplot * 0.9.4 2019-01-08 [1]
## crayon 1.3.4 2017-09-16 [1]
## cvTools * 0.3.2 2012-05-14 [1]
## data.table 1.11.8 2018-09-30 [1]
## DEoptimR 1.0-8 2016-11-19 [1]
## desc 1.2.0 2018-05-01 [1]
## devtools 2.0.1 2018-10-26 [1]
## digest 0.6.20 2019-07-04 [1]
## dplyr 0.8.3 2019-07-04 [1]
## evaluate 0.14 2019-05-28 [1]
## fakeDataWithError * 0.0.1 2018-12-13 [1]
## foreach 1.4.4 2017-12-12 [1]
## fs 1.3.1.9000 2019-07-10 [1]
## gdata 2.18.0 2017-06-06 [1]
## generics 0.0.2 2018-11-29 [1]
## ggplot2 * 3.1.1 2019-04-07 [1]
## glue 1.3.1 2019-03-12 [1]
## gower 0.1.2 2017-02-23 [1]
## gplots * 3.0.1 2016-03-30 [1]
## gtable 0.3.0 2019-03-25 [1]
## gtools 3.8.1 2018-06-26 [1]
## highr 0.8 2019-03-20 [1]
## htmltools 0.3.6 2017-04-28 [1]
## ipred 0.9-8 2018-11-05 [1]
## iterators 1.0.10 2018-07-13 [1]
## KernSmooth 2.23-15 2015-06-29 [1]
## knitr 1.24 2019-08-08 [1]
## labeling 0.3 2014-08-23 [1]
## lattice * 0.20-38 2018-11-04 [1]
## lava 1.6.4 2018-11-25 [1]
## lazyeval 0.2.2 2019-03-15 [1]
## lubridate 1.7.4 2018-04-11 [1]
## magrittr 1.5 2014-11-22 [1]
## MASS 7.3-51.1 2018-11-01 [1]
## Matrix 1.2-15 2018-11-01 [1]
## memoise 1.1.0 2017-04-21 [1]
## ModelMetrics 1.2.2 2018-11-03 [1]
## munsell 0.5.0 2018-06-12 [1]
## nlme 3.1-137 2018-04-07 [1]
## nnet 7.3-12 2016-02-02 [1]
## pillar 1.4.2 2019-06-29 [1]
## pkgbuild 1.0.2 2018-10-16 [1]
## pkgconfig 2.0.2 2018-08-16 [1]
## pkgload 1.0.2 2018-10-29 [1]
## pls 2.7-0 2018-08-21 [1]
## plyr 1.8.4 2016-06-08 [1]
## prettyunits 1.0.2 2015-07-13 [1]
## processx 3.2.1 2018-12-05 [1]
## prodlim 2018.04.18 2018-04-18 [1]
## ps 1.3.0 2018-12-21 [1]
## purrr 0.3.2 2019-03-15 [1]
## R6 2.4.0 2019-02-14 [1]
## randomForest * 4.6-14 2018-03-25 [1]
## Rcpp 1.0.2 2019-07-25 [1]
## recipes 0.1.4 2018-11-19 [1]
## remotes 2.0.2 2018-10-30 [1]
## reshape2 1.4.3 2017-12-11 [1]
## rlang 0.4.0.9002 2019-08-27 [1]
## rmarkdown 1.15 2019-08-21 [1]
## robustbase * 0.93-3 2018-09-21 [1]
## ROCR * 1.0-7 2015-03-26 [1]
## rpart 4.1-13 2018-02-23 [1]
## rprojroot 1.3-2 2018-01-03 [1]
## scales 1.0.0 2018-08-09 [1]
## sessioninfo 1.1.1 2018-11-05 [1]
## stringi 1.4.3 2019-03-12 [1]
## stringr 1.4.0 2019-02-10 [1]
## survival 2.43-3 2018-11-26 [1]
## testthat 2.0.1 2018-10-13 [1]
## tibble 2.1.3 2019-06-06 [1]
## tidyselect 0.2.5 2018-10-11 [1]
## timeDate 3043.102 2018-02-21 [1]
## usethis 1.4.0 2018-08-14 [1]
## withr 2.1.2 2018-03-15 [1]
## xfun 0.8 2019-06-25 [1]
## yaml 2.2.0 2018-07-25 [1]
## source: CRAN (R 3.5.1) for every package listed above, except:
##   fakeDataWithError   Github (rmflight/fakeDataWithError@ccd8714)
##   fs                  Github (r-lib/fs@00e2de8)
##   rlang               Github (r-lib/rlang@15e799c)
##
## [1] /software/R_libs/R351_bioc37
## [2] /software/R-351/lib/R/library