Random Forest vs PLS on Random Data

random-forest machine-learning partial-least-squares statistics analysis

Comparing random forest and partial least squares discriminant analysis (PLS-DA) on random data to show the problems inherent in PLS-DA.

Robert M Flight
2015-12-12

TL;DR

Partial least squares discriminant analysis (PLS-DA) can overfit ridiculously, even on completely random data. The quality of a PLS-DA model can be assessed using cross-validation, but many metabolomics publications do not perform it. Random forest, in contrast, tests each decision tree in the forest on its out-of-bag (OOB) samples, so it automatically provides an indication of the quality of the model.

Why?

I’ve recently been doing some machine learning work using random forests (RF; Breiman, 2001) on metabolomics data. This has been relatively successful, with decent sensitivity and specificity, and hopefully I’ll be able to post more soon. However, PLS (Wold, 1975) is a standard technique in metabolomics, due to the prevalence of analytical chemists in the field and their long familiarity with the method. Importantly, my collaborators frequently use PLS-DA to generate plots to show that the various classes of samples are separable.

However, it has long been known that PLS (and all of its variants: PLS-DA, OPLS, OPLS-DA, etc.) can easily generate models that overfit the data, and that overfitting needs to be assessed if the model is going to be used in subsequent analyses.

Random Data

To illustrate the behavior of both RF and PLS-DA, we will generate some random data and assign each sample to one of two classes, with no relationship between the class and the data.

Feature Intensities

We will generate a data set with 1000 features, where each feature’s mean value is drawn from a uniform distribution over the range 0 to 10000.

library(ggplot2)
theme_set(cowplot::theme_cowplot())
library(fakeDataWithError)
set.seed(1234)
n_point <- 1000
max_value <- 10000
init_values <- runif(n_point, 0, max_value)
init_data <- data.frame(data = init_values)
ggplot(init_data, aes(x = data)) + geom_histogram() + ggtitle("Initial Data")

For each of these features, the values across samples are drawn from a normal distribution whose mean is the initial feature value and whose standard deviation is 200. The number of samples is 100.

n_sample <- 100
error_values <- add_uniform_noise(n_sample, init_values, 200)

Just for information, this is the add_uniform_noise function. Despite its name, it draws normally distributed noise with rnorm:

add_uniform_noise
function (n_rep, value, sd, use_zero = FALSE) 
{
    n_value <- length(value)
    n_sd <- n_rep * n_value
    out_sd <- rnorm(n_sd, 0, sd)
    out_sd <- matrix(out_sd, nrow = n_value, ncol = n_rep)
    if (!use_zero) {
        tmp_value <- matrix(value, nrow = n_value, ncol = n_rep, 
            byrow = FALSE)
        out_value <- tmp_value + out_sd
    }
    else {
        out_value <- out_sd
    }
    return(out_value)
}
<bytecode: 0x55bf33f2aa58>
<environment: namespace:fakeDataWithError>

I created it as part of a package that is able to add different kinds of noise to data.
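
Since fakeDataWithError is installed from a local source here, readers without it may want a stand-in. The following is a minimal base-R sketch that mirrors only the use_zero = FALSE branch printed above. The name add_normal_noise is my own and does not exist in the package.

```r
# Hypothetical base-R stand-in for add_uniform_noise (use_zero = FALSE branch only).
add_normal_noise <- function(n_rep, value, sd) {
  n_value <- length(value)
  # One row per feature, one column per sample; noise is N(0, sd).
  noise <- matrix(rnorm(n_rep * n_value, mean = 0, sd = sd),
                  nrow = n_value, ncol = n_rep)
  # `value` is recycled down the columns, so row i is centred on value[i].
  value + noise
}

set.seed(1234)
feature_means <- runif(1000, 0, 10000)
sim_values <- add_normal_noise(100, feature_means, 200)
dim(sim_values)
```

The result has features in rows and samples in columns, which is why the later calls transpose it with t() before modelling.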

The distribution of values for a single feature looks like this:

error_data <- data.frame(feature_1 = error_values[1,])
ggplot(error_data, aes(x = feature_1)) + geom_histogram() + ggtitle("Error Data")

And we will assign the first 50 samples to class_1 and the second 50 samples to class_2.

sample_class <- rep(c("class_1", "class_2"), each = 50)
sample_class
  [1] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
  [7] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
 [13] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
 [19] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
 [25] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
 [31] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
 [37] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
 [43] "class_1" "class_1" "class_1" "class_1" "class_1" "class_1"
 [49] "class_1" "class_1" "class_2" "class_2" "class_2" "class_2"
 [55] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
 [61] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
 [67] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
 [73] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
 [79] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
 [85] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
 [91] "class_2" "class_2" "class_2" "class_2" "class_2" "class_2"
 [97] "class_2" "class_2" "class_2" "class_2"

PCA

Just to show that the data is pretty random, let’s use principal components analysis (PCA) to do a decomposition, and plot the first two components:

tmp_pca <- prcomp(t(error_values), center = TRUE, scale. = TRUE)
pca_data <- as.data.frame(tmp_pca$x[, 1:2])
pca_data$class <- as.factor(sample_class)
ggplot(pca_data, aes(x = PC1, y = PC2, color = class)) + geom_point(size = 4)
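
A complementary check is to test each feature for a difference between the two classes. This is sketched on freshly simulated data of the same shape rather than on the exact error_values object, and the variable names are mine. With no real signal, roughly 5% of features should reach p < 0.05 purely by chance.

```r
set.seed(1234)
n_feature <- 1000
n_sample <- 100
feature_means <- runif(n_feature, 0, 10000)
# Rows are features, columns are samples, as in error_values.
sim_values <- feature_means + matrix(rnorm(n_feature * n_sample, 0, 200),
                                     nrow = n_feature, ncol = n_sample)
sim_class <- factor(rep(c("class_1", "class_2"), each = 50))

# Welch t-test per feature comparing the two classes.
p_values <- apply(sim_values, 1, function(x) t.test(x ~ sim_class)$p.value)

# Fraction of features that look "significant" at 0.05.
frac_sig <- mean(p_values < 0.05)
frac_sig
```

A fraction close to 0.05 is what chance alone produces, so a handful of "significant" features in data like this is not evidence of class differences.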

Random Forest

Let’s use RF first, and see how things look.

library(randomForest)
rf_model <- randomForest(t(error_values), y = as.factor(sample_class))

The confusion matrix comparing actual vs predicted classes based on the out-of-bag (OOB) samples:

knitr::kable(rf_model$confusion)
        class_1 class_2 class.error
class_1      21      29        0.58
class_2      28      22        0.56

And an overall error of 0.5760364.
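
To make the link between the confusion matrix and the error rates explicit, here is a small sketch that recomputes them from the counts printed above. The counts are typed in by hand, not taken from rf_model. These counts give an overall error of 0.57, so the figure quoted above was presumably computed in some other way. That is a guess on my part.

```r
# Counts copied from the confusion matrix above (rows = actual, columns = predicted).
conf_mat <- matrix(c(21, 28, 29, 22), nrow = 2,
                   dimnames = list(actual = c("class_1", "class_2"),
                                   predicted = c("class_1", "class_2")))

# Per-class error: fraction of each actual class that was misclassified.
class_error <- 1 - diag(conf_mat) / rowSums(conf_mat)
# Overall error: all off-diagonal counts over the total.
overall_error <- 1 - sum(diag(conf_mat)) / sum(conf_mat)

class_error
overall_error
```

Either way, the error is slightly worse than a coin flip, which is the honest answer for random data.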

PLS-DA

So PLS-DA is really just PLS with a binary y variable.

library(caret)
pls_model <- plsda(t(error_values), as.factor(sample_class), ncomp = 2)
pls_scores <- data.frame(comp1 = pls_model$scores[,1], comp2 = pls_model$scores[,2], class = sample_class)

And plot the PLS scores:

ggplot(pls_scores, aes(x = comp1, y = comp2, color = class)) + geom_point(size = 4) + ggtitle("PLS-DA of Random Data")

And voila! Perfectly separated data! If I didn’t tell you that it was random, would you suspect it?
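
One reason to be suspicious is that there are ten times more features than samples. In that regime a supervised linear method can almost always find some combination of features that lines up with the labels, whatever the labels are. The sketch below shows this with ordinary least squares rather than PLS (my substitution, to stay in base R) on a smaller random data set. All names and sizes are illustrative.

```r
set.seed(42)
n_train <- 20
n_feature <- 50
x_train <- matrix(rnorm(n_train * n_feature), nrow = n_train)
# Labels coded as -1 / +1 and assigned without reference to x_train.
y_train <- rep(c(-1, 1), each = n_train / 2)

# Least squares with more columns than rows: lm.fit pivots out the
# redundant columns and still reproduces y_train exactly.
fit <- lm.fit(x_train, y_train)
coefs <- fit$coefficients
coefs[is.na(coefs)] <- 0  # aliased columns contribute nothing

train_acc <- mean(sign(x_train %*% coefs) == y_train)

# Fresh random data from the same process: the "model" has nothing to offer.
n_test <- 2000
x_test <- matrix(rnorm(n_test * n_feature), nrow = n_test)
y_test <- sample(c(-1, 1), n_test, replace = TRUE)
test_acc <- mean(sign(x_test %*% coefs) == y_test)

c(train = train_acc, test = test_acc)
```

Training accuracy is perfect while accuracy on new data sits near 0.5. That is the same pattern as a clean scores plot on the data the model was fitted to.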

Cross-validated PLS-DA

Of course, one way to truly assess the worth of the model would be to use cross-validation, where a fraction of the data is held back and the model is trained on the rest. Predictions are then made on the held-back fraction, and because we know the truth, we can calculate the area under the receiver operating characteristic curve (AUROC), or simply area under the curve (AUC), which is the curve created by plotting the true positive rate against the false positive rate.
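
For intuition about what the AUC measures, it can also be computed without drawing the curve: it equals the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one. Here is a minimal base-R sketch of that rank-based identity. The function name auc_by_rank is made up for this illustration and is not part of ROCR.

```r
# AUC via the rank-sum (Mann-Whitney) identity: the probability that a
# randomly chosen positive scores higher than a randomly chosen negative.
# auc_by_rank is a hypothetical helper, not a function from any package.
auc_by_rank <- function(scores, is_positive) {
  n_pos <- sum(is_positive)
  n_neg <- sum(!is_positive)
  ranks <- rank(scores)  # ties get average ranks
  (sum(ranks[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

scores <- c(0.10, 0.40, 0.35, 0.80)
is_positive <- c(FALSE, FALSE, TRUE, TRUE)
auc_by_rank(scores, is_positive)
```

In this toy example three of the four positive-negative pairs are ordered correctly, so the AUC is 0.75. A value of 0.5 means the scores carry no information about the class.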

To do this we will need two functions:

  1. Generates all of the CV folds
  2. Generates PLS-DA model, does prediction on hold out, calculates AUC

library(cvTools)
library(ROCR)

gen_cv <- function(xdata, ydata, nrep, kfold){
  n_sample <- length(ydata)
  all_index <- seq(1, n_sample)
  # nrep independent random partitions of the samples into kfold folds
  cv_data <- cvFolds(n_sample, K = kfold, R = nrep, type = "random")
  
  # one column of fold AUCs per replicate (kfold x nrep matrix)
  rep_values <- vapply(seq(1, nrep), function(in_rep){
    use_rep <- cv_data$subsets[, in_rep]
    cv_values <- vapply(seq(1, kfold), function(in_fold){
      # hold out one fold, train on everything else
      test_index <- use_rep[cv_data$which == in_fold]
      train_index <- all_index[-test_index]
      
      plsda_cv(xdata[train_index, ], ydata[train_index], xdata[test_index, ],
               ydata[test_index])
    }, numeric(1))
  }, numeric(kfold))
  rep_values
}

plsda_cv <- function(xtrain, ytrain, xtest, ytest){
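  pls_model <- plsda(xtrain, ytrain, ncomp = 2)
  # class probabilities for the held-out samples
  pls_pred <- predict(pls_model, xtest, type = "prob")
  
  # probability of the second class (class_2)
  use_pred <- pls_pred[, 2, 1]
  
  # area under the ROC curve on the held-out samples
  pred_perf <- ROCR::prediction(use_pred, ytest)
  pred_auc <- ROCR::performance(pred_perf, "auc")@y.values[[1]]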
  return(pred_auc)
}
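
cvTools handles the fold bookkeeping here. To see what that amounts to, a single repeat of random K-fold assignment can be sketched in base R. This illustrates the idea only. It is not a description of what cvFolds does internally, and make_folds is a name I made up.

```r
# Hypothetical helper: randomly assign n samples to k folds of near-equal size.
make_folds <- function(n, k) {
  sample(rep(seq_len(k), length.out = n))
}

set.seed(1234)
fold_id <- make_folds(100, 5)
table(fold_id)

# Indices for the first fold: held-out samples and the rest.
test_index <- which(fold_id == 1)
train_index <- setdiff(seq_len(100), test_index)
```

Each sample lands in exactly one fold, so every sample is predicted once per repeat by a model that never saw it. The folds are not stratified by class, so the class balance of each held-out set varies from fold to fold.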

And now let’s do a bunch of replicates (100).

cv_vals <- gen_cv(t(error_values), factor(sample_class), nrep = 100, kfold = 5)

mean(cv_vals)
[1] 0.4260387
sd(cv_vals)
[1] 0.1188491
cv_frame <- data.frame(auc = as.vector(cv_vals))
ggplot(cv_frame, aes(x = auc)) + geom_histogram(binwidth = 0.01)

So we get an average AUC of 0.4260387, which is pretty awful: a model with no predictive ability at all would be expected to score about 0.5. This implies that even though there was good separation on the scores, the model is not actually that good, and we should be cautious of any predictions it makes.
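
For a sense of scale, here is a quick simulation of the AUC obtained from scores that are pure noise. It rests on a simplifying assumption of mine: exactly 10 samples per class in each held-out set of 20, which the random folds above do not guarantee.

```r
set.seed(1234)
n_per_class <- 10
is_positive <- rep(c(FALSE, TRUE), each = n_per_class)

null_auc <- replicate(2000, {
  scores <- runif(2 * n_per_class)  # scores unrelated to the labels
  ranks <- rank(scores)
  # rank-sum form of the AUC
  (sum(ranks[is_positive]) - n_per_class * (n_per_class + 1) / 2) /
    n_per_class^2
})

c(mean = mean(null_auc), sd = sd(null_auc))
```

The simulated mean sits near 0.5 with a standard deviation of roughly 0.13, the same order as the spread of the cross-validated AUC values above. With held-out sets this small, single-fold AUC values are very noisy, which is why many replicates are needed.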

Of course, the PCA at the beginning of the analysis shows that there is no real separation in the data in the first place.

devtools::session_info()
─ Session info ─────────────────────────────────────────────────────
 setting  value                       
 version  R version 4.0.0 (2020-04-24)
 os       Pop!_OS 20.04 LTS           
 system   x86_64, linux-gnu           
 ui       X11                         
 language en_US:en                    
 collate  en_US.UTF-8                 
 ctype    en_US.UTF-8                 
 tz       America/New_York            
 date     2021-02-27                  

─ Packages ─────────────────────────────────────────────────────────
 package           * version    date       lib source        
 assertthat          0.2.1      2019-03-21 [1] CRAN (R 4.0.0)
 cachem              1.0.4      2021-02-13 [1] CRAN (R 4.0.0)
 callr               3.5.1      2020-10-13 [1] CRAN (R 4.0.0)
 caret             * 6.0-86     2020-03-20 [1] CRAN (R 4.0.0)
 class               7.3-18     2021-01-24 [1] CRAN (R 4.0.0)
 cli                 2.3.0      2021-01-31 [1] CRAN (R 4.0.0)
 codetools           0.2-18     2020-11-04 [1] CRAN (R 4.0.0)
 colorspace          2.0-0      2020-11-11 [1] CRAN (R 4.0.0)
 cowplot             1.1.1      2020-12-30 [1] CRAN (R 4.0.0)
 crayon              1.4.1      2021-02-08 [1] CRAN (R 4.0.0)
 cvTools           * 0.3.2      2012-05-14 [1] CRAN (R 4.0.0)
 data.table          1.13.6     2020-12-30 [1] CRAN (R 4.0.0)
 DBI                 1.1.1      2021-01-15 [1] CRAN (R 4.0.0)
 DEoptimR            1.0-8      2016-11-19 [1] CRAN (R 4.0.0)
 desc                1.2.0      2018-05-01 [1] CRAN (R 4.0.0)
 devtools            2.3.2      2020-09-18 [1] CRAN (R 4.0.0)
 digest              0.6.27     2020-10-24 [1] CRAN (R 4.0.0)
 distill             1.2        2021-01-13 [1] CRAN (R 4.0.0)
 downlit             0.2.1      2020-11-04 [1] CRAN (R 4.0.0)
 dplyr               1.0.4      2021-02-02 [1] CRAN (R 4.0.0)
 ellipsis            0.3.1      2020-05-15 [1] CRAN (R 4.0.0)
 evaluate            0.14       2019-05-28 [1] CRAN (R 4.0.0)
 fakeDataWithError * 0.0.1      2020-05-27 [1] local         
 fansi               0.4.2      2021-01-15 [1] CRAN (R 4.0.0)
 farver              2.0.3      2020-01-16 [1] CRAN (R 4.0.0)
 fastmap             1.1.0      2021-01-25 [1] CRAN (R 4.0.0)
 foreach             1.5.1      2020-10-15 [1] CRAN (R 4.0.0)
 fs                  1.5.0      2020-07-31 [1] CRAN (R 4.0.0)
 generics            0.1.0      2020-10-31 [1] CRAN (R 4.0.0)
 ggplot2           * 3.3.3      2020-12-30 [1] CRAN (R 4.0.0)
 glue                1.4.2      2020-08-27 [1] CRAN (R 4.0.0)
 gower               0.2.2      2020-06-23 [1] CRAN (R 4.0.0)
 gtable              0.3.0      2019-03-25 [1] CRAN (R 4.0.0)
 highr               0.8        2019-03-20 [1] CRAN (R 4.0.0)
 htmltools           0.5.1.1    2021-01-22 [1] CRAN (R 4.0.0)
 ipred               0.9-9      2019-04-28 [1] CRAN (R 4.0.0)
 iterators           1.0.13     2020-10-15 [1] CRAN (R 4.0.0)
 knitr               1.31       2021-01-27 [1] CRAN (R 4.0.0)
 labeling            0.4.2      2020-10-20 [1] CRAN (R 4.0.0)
 lattice           * 0.20-41    2020-04-02 [1] CRAN (R 4.0.0)
 lava                1.6.8.1    2020-11-04 [1] CRAN (R 4.0.0)
 lifecycle           1.0.0      2021-02-15 [1] CRAN (R 4.0.0)
 lubridate           1.7.9.2    2020-11-13 [1] CRAN (R 4.0.0)
 magrittr            2.0.1      2020-11-17 [1] CRAN (R 4.0.0)
 MASS                7.3-53.1   2021-02-12 [1] CRAN (R 4.0.0)
 Matrix              1.3-2      2021-01-06 [1] CRAN (R 4.0.0)
 memoise             2.0.0      2021-01-26 [1] CRAN (R 4.0.0)
 ModelMetrics        1.2.2.2    2020-03-17 [1] CRAN (R 4.0.0)
 munsell             0.5.0      2018-06-12 [1] CRAN (R 4.0.0)
 nlme                3.1-152    2021-02-04 [1] CRAN (R 4.0.0)
 nnet                7.3-15     2021-01-24 [1] CRAN (R 4.0.0)
 pillar              1.4.7      2020-11-20 [1] CRAN (R 4.0.0)
 pkgbuild            1.2.0      2020-12-15 [1] CRAN (R 4.0.0)
 pkgconfig           2.0.3      2019-09-22 [1] CRAN (R 4.0.0)
 pkgload             1.1.0      2020-05-29 [1] CRAN (R 4.0.0)
 pls                 2.7-3      2020-08-07 [1] CRAN (R 4.0.0)
 plyr                1.8.6      2020-03-03 [1] CRAN (R 4.0.0)
 prettyunits         1.1.1      2020-01-24 [1] CRAN (R 4.0.0)
 pROC                1.17.0.1   2021-01-13 [1] CRAN (R 4.0.0)
 processx            3.4.5      2020-11-30 [1] CRAN (R 4.0.0)
 prodlim             2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)
 ps                  1.5.0      2020-12-05 [1] CRAN (R 4.0.0)
 purrr               0.3.4      2020-04-17 [1] CRAN (R 4.0.0)
 R6                  2.5.0      2020-10-28 [1] CRAN (R 4.0.0)
 randomForest      * 4.6-14     2018-03-25 [1] CRAN (R 4.0.0)
 Rcpp                1.0.6      2021-01-15 [1] CRAN (R 4.0.0)
 recipes             0.1.15     2020-11-11 [1] CRAN (R 4.0.0)
 remotes             2.2.0      2020-07-21 [1] CRAN (R 4.0.0)
 reshape2            1.4.4      2020-04-09 [1] CRAN (R 4.0.0)
 rlang               0.4.10     2020-12-30 [1] CRAN (R 4.0.0)
 rmarkdown           2.6        2020-12-14 [1] CRAN (R 4.0.0)
 robustbase        * 0.93-7     2021-01-04 [1] CRAN (R 4.0.0)
 ROCR              * 1.0-11     2020-05-02 [1] CRAN (R 4.0.0)
 rpart               4.1-15     2019-04-12 [1] CRAN (R 4.0.0)
 rprojroot           2.0.2      2020-11-15 [1] CRAN (R 4.0.0)
 scales              1.1.1      2020-05-11 [1] CRAN (R 4.0.0)
 sessioninfo         1.1.1      2018-11-05 [1] CRAN (R 4.0.0)
 stringi             1.5.3      2020-09-09 [1] CRAN (R 4.0.0)
 stringr             1.4.0      2019-02-10 [1] CRAN (R 4.0.0)
 survival            3.2-7      2020-09-28 [1] CRAN (R 4.0.0)
 testthat            3.0.2      2021-02-14 [1] CRAN (R 4.0.0)
 tibble              3.0.6      2021-01-29 [1] CRAN (R 4.0.0)
 tidyselect          1.1.0      2020-05-11 [1] CRAN (R 4.0.0)
 timeDate            3043.102   2018-02-21 [1] CRAN (R 4.0.0)
 usethis             2.0.1      2021-02-10 [1] CRAN (R 4.0.0)
 vctrs               0.3.6      2020-12-17 [1] CRAN (R 4.0.0)
 withr               2.4.1      2021-01-26 [1] CRAN (R 4.0.0)
 xfun                0.21       2021-02-10 [1] CRAN (R 4.0.0)
 yaml                2.2.1      2020-02-01 [1] CRAN (R 4.0.0)

[1] /software/R_libs/R400
[2] /software/R-4.0.0/lib/R/library

Reuse

Text and figures are licensed under Creative Commons Attribution CC BY 4.0. Source code is available at https://github.com/rmflight/researchBlog_distill, unless otherwise noted. The figures that have been reused from other sources don't fall under this license and can be recognized by a note in their caption: "Figure from ...".

Citation

For attribution, please cite this work as

Flight (2015, Dec. 12). Deciphering Life: One Bit at a Time: Random Forest vs PLS on Random Data. Retrieved from https://rmflight.github.io/posts/2015-12-12-random-forest-vs-pls-on-random-data/

BibTeX citation

@misc{flight2015random,
  author = {Flight, Robert M},
  title = {Deciphering Life: One Bit at a Time: Random Forest vs PLS on Random Data},
  url = {https://rmflight.github.io/posts/2015-12-12-random-forest-vs-pls-on-random-data/},
  year = {2015}
}