diff --git a/paper/config.json b/config/config.json similarity index 81% rename from paper/config.json rename to config/config.json index 4f1dd0d..29cd68e 100644 --- a/paper/config.json +++ b/config/config.json @@ -1,5 +1,8 @@ { "random_state": 42, + "datasets_last": [ + "BLCA" + ], "datasets": [ "BLCA", "BRCA", @@ -12,8 +15,9 @@ "OV", "STAD" ], - "datasets_last": ["LUSC"], - "datasets_lel": ["BLCA"], + "datasets_lel": [ + "BLCA" + ], "n_outer_splits": 5, "n_outer_repetitions": 25, "pc_n_components": 16, @@ -25,7 +29,15 @@ "seed": 42, "shuffle_cv": true, "timing_reps": 5, - "l1_ratio_tuned": [0.1, 0.5, 0.7, 0.9, 0.95, 0.99, 1], + "l1_ratio_tuned": [ + 0.1, + 0.5, + 0.7, + 0.9, + 0.95, + 0.99, + 1 + ], "pc_n_components_tuned": [ 8, 16, @@ -66,7 +78,9 @@ 0.4, 0.6 ], - "tune_batch_size": [1024], + "tune_batch_size": [ + 1024 + ], "random_search_n_iter": 50, "error_score": 100 } diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000..ac2b46e --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,8 @@ +snakemake_min_version: "7.30.1" +mambaforge_version: "23.3.1-0" +config_path: "config/config.json" +datasets: + ["BLCA", "BRCA", "HNSC", "KIRC", "LGG", "LIHC", "LUAD", "LUSC", "OV", "STAD"] +random_seed: 42 +n_outer_repetitions: 25 +n_outer_splits: 5 diff --git a/paper/.Rprofile b/paper/.Rprofile deleted file mode 100644 index 81b960f..0000000 --- a/paper/.Rprofile +++ /dev/null @@ -1 +0,0 @@ -source("renv/activate.R") diff --git a/paper/README.md b/paper/README.md deleted file mode 100644 index b32e59c..0000000 --- a/paper/README.md +++ /dev/null @@ -1,46 +0,0 @@ -# *sparsesurv*: Sparse survival models via knowledge distillation -## Abstract -Sparse survival models are statistical models that select a subset of predictor variables while modeling the time until an event occurs, which can improve interpretability and transportability. The subset of important features is typically obtained with regularized models, such as the Cox Proportional Hazards model with Lasso regularization, which limit the number of non-zero coefficients. However, such models can be sensitive to the choice of regularization hyperparameter. In this work, we demonstrate how knowledge distillation, a powerful technique in machine learning that aims to transfer knowledge from a complex teacher model to a simpler student model, can be leveraged to learn sparse models while mitigating the challenge above. We present sparsesurv, a Python package that contains a set of teacher-student model pairs, including the semi-parametric accelerated failure time and the extended hazards models as teachers, which currently do not have Python implementations. It also contains in-house survival function estimators, removing the need for external packages. Sparsesurv is validated against R-based Elastic Net regularized linear Cox proportional hazards models as implemented in the commonly used *glmnet* package. Our results reveal that knowledge distillation-based approaches achieve better discriminative performance across the regularization path while making the choice of the regularization hyperparameter significantly easier. All of these features, combined with an *sklearn*-like API, make sparsesurv an easy-to-use Python package that enables survival analysis for high-dimensional datasets and allows fitting sparse survival models via knowledge distillation.
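The *sklearn*-like API that the abstract mentions is visible in the deleted scripts further down in this diff (`KDSurv`, `KDPHElasticNetCV`, `transform_survival`). A minimal usage sketch built only from those names might look as follows; the toy data and the reliance on default hyperparameters are illustrative assumptions, not the paper's actual setup.

```python
# Minimal sketch of the teacher-student API, assuming the names used by the
# deleted scripts in this diff; toy data and default hyperparameters are
# assumptions for illustration only.
import numpy as np
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sksurv.linear_model import CoxPHSurvivalAnalysis

from sparsesurv._base import KDSurv
from sparsesurv.cv import KDPHElasticNetCV
from sparsesurv.utils import transform_survival

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 500))               # toy high-dimensional design matrix
y = transform_survival(
    time=rng.uniform(1.0, 3650.0, size=200),  # toy follow-up times
    event=rng.integers(0, 2, size=200),       # toy event indicators
)

pipe = KDSurv(
    # Teacher: dense PCA + Cox model, as in the deleted run_sparsesurv.py.
    teacher=make_pipeline(
        StandardScaler(),
        PCA(n_components=16),
        CoxPHSurvivalAnalysis(ties="efron"),
    ),
    # Student: sparse elastic net distilled from the teacher's predictions.
    student=make_pipeline(
        StandardScaler(),
        KDPHElasticNetCV(tie_correction="efron", l1_ratio=0.9),
    ),
)
pipe.fit(X, y)
print("non-zero coefficients:", np.sum(pipe.student[-1].coef_ != 0))
```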
- -## Reproducibility -### From scratch -Since installing conda environments and R packages from bash scripts can be finicky and requires user input anyway, we guide you through the process below, after which you may reproduce all of our results by executing our reproduction script. - -#### Python -Please run the following in a terminal and give user input as appropriate (e.g., confirming that you want to create a new conda env). - -```sh -conda create -n sparsesurv_paper python==3.10.0 -conda activate sparsesurv_paper -pip install -r requirements.txt -pip install -e .. -``` - -#### R -We require you to have R 4.2.2 installed; we recommend using [Rig](https://github.com/r-lib/rig) to manage different R versions. - -Supposing you already have R 4.2.2, you may simply run the following in a terminal. - -```sh -Rscript -e "install.packages('renv');require(renv);renv::activate();renv::restore()" -``` - -#### Running experiments -Once both the necessary R and Python packages are installed, you may reproduce all of our work (including data downloads, preprocessing, etc.) by running the following in a terminal (make sure to activate the respective conda environment if it is not already active). - -```sh -bash reproduce.sh -``` - -### Results -All of our results, including preprocessed data, computed performance metrics, and predicted survival functions for all models and experiments, are available on [Zenodo](https://zenodo.org/records/11058330). - -## Questions -In case of any questions, please reach out to david.wissel@inf.ethz.ch or open an issue in this repo. - -## Citation -Our manuscript is still under review. - -## References -[1] Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. "Distilling the knowledge in a neural network." arXiv preprint arXiv:1503.02531 (2015). -[2] Paul, Debashis, et al. "'Preconditioning' for Feature Selection and Regression in High-Dimensional Problems." The Annals of Statistics (2008): 1595-1618.
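Every script removed in this diff starts by reading the JSON config, and the rename at the top of the diff moves that file from paper/config.json to config/config.json. A minimal sketch of consuming it from the new location follows; it mirrors the `json.load` pattern in the deleted scripts and uses only keys shown in the diff, while the relative path assumes execution from the repository root.

```python
# Minimal sketch: load the relocated config and derive the resampling grid.
# Keys and values mirror config/config.json as shown at the top of this diff;
# the relative path is an assumption (execution from the repository root).
import json

import numpy as np

with open("config/config.json") as f:
    config = json.load(f)

np.random.seed(config["random_state"])  # 42, per the diff above

# 5 outer splits x 25 repetitions = 125 train/test resamples per cancer type.
n_resamples = config["n_outer_splits"] * config["n_outer_repetitions"]
for cancer in config["datasets"]:  # "BLCA", "BRCA", ..., "STAD"
    print(f"{cancer}: {n_resamples} resamples")
```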
diff --git a/paper/paper.Rproj b/paper/paper.Rproj deleted file mode 100644 index 8e3c2eb..0000000 --- a/paper/paper.Rproj +++ /dev/null @@ -1,13 +0,0 @@ -Version: 1.0 - -RestoreWorkspace: Default -SaveWorkspace: Default -AlwaysSaveHistory: Default - -EnableCodeIndexing: Yes -UseSpacesForTab: Yes -NumSpacesForTab: 2 -Encoding: UTF-8 - -RnwWeave: Sweave -LaTeX: pdfLaTeX diff --git a/paper/renv.lock b/paper/renv.lock deleted file mode 100644 index fc27639..0000000 --- a/paper/renv.lock +++ /dev/null @@ -1,2191 +0,0 @@ -{ - "R": { - "Version": "4.2.2", - "Repositories": [ - { - "Name": "CRAN", - "URL": "https://cloud.r-project.org" - } - ] - }, - "Packages": { - "Formula": { - "Package": "Formula", - "Version": "1.2-5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "stats" - ], - "Hash": "7a29697b75e027767a53fde6c903eca7" - }, - "Hmisc": { - "Package": "Hmisc", - "Version": "5.1-1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Formula", - "base64enc", - "cluster", - "colorspace", - "data.table", - "foreign", - "ggplot2", - "grid", - "gridExtra", - "gtable", - "htmlTable", - "htmltools", - "knitr", - "methods", - "nnet", - "rmarkdown", - "rpart", - "viridis" - ], - "Hash": "27c7750992bc511728e2f9ebd2911199" - }, - "KernSmooth": { - "Package": "KernSmooth", - "Version": "2.23-20", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "stats" - ], - "Hash": "8dcfa99b14c296bc9f1fd64d52fd3ce7" - }, - "MASS": { - "Package": "MASS", - "Version": "7.3-58.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "methods", - "stats", - "utils" - ], - "Hash": "762e1804143a332333c054759f89a706" - }, - "Matrix": { - "Package": "Matrix", - "Version": "1.6-5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "grid", - "lattice", - "methods", - "stats", - "utils" - ], - "Hash": "8c7115cd3a0e048bda2a7cd110549f7a" - }, - "MatrixModels": { - "Package": "MatrixModels", - "Version": "0.5-3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "R", - "methods", - "stats" - ], - "Hash": "0776bf7526869e0286b0463cb72fb211" - }, - "Publish": { - "Package": "Publish", - "Version": "2023.01.17", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "data.table", - "lava", - "multcomp", - "prodlim", - "survival" - ], - "Hash": "2f0c7247e7d173efe2b022f8053c8de6" - }, - "R6": { - "Package": "R6", - "Version": "2.5.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "470851b6d5d0ac559e9d01bb352b4021" - }, - "RColorBrewer": { - "Package": "RColorBrewer", - "Version": "1.1-3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "45f0398006e83a5b10b72a90663d8d8c" - }, - "Rcpp": { - "Package": "Rcpp", - "Version": "1.0.12", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "methods", - "utils" - ], - "Hash": "5ea2700d21e038ace58269ecdbeb9ec0" - }, - "RcppArmadillo": { - "Package": "RcppArmadillo", - "Version": "0.12.6.6.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "Rcpp", - "methods", - "stats", - "utils" - ], - "Hash": "d2b60e0a15d73182a3a766ff0a7d0d7f" - }, - "RcppEigen": { - "Package": "RcppEigen", - "Version": "0.3.3.9.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "Rcpp", - "stats", - "utils" - ], - 
"Hash": "acb0a5bf38490f26ab8661b467f4f53a" - }, - "SQUAREM": { - "Package": "SQUAREM", - "Version": "2021.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "0cf10dab0d023d5b46a5a14387556891" - }, - "SparseM": { - "Package": "SparseM", - "Version": "1.81", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "graphics", - "methods", - "stats", - "utils" - ], - "Hash": "2042cd9759cc89a453c4aefef0ce9aae" - }, - "TH.data": { - "Package": "TH.data", - "Version": "1.1-2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "MASS", - "R", - "survival" - ], - "Hash": "5b250ad4c5863ee4a68e280fcb0a3600" - }, - "askpass": { - "Package": "askpass", - "Version": "1.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "sys" - ], - "Hash": "cad6cf7f1d5f6e906700b9d3e718c796" - }, - "assertthat": { - "Package": "assertthat", - "Version": "0.2.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "tools" - ], - "Hash": "50c838a310445e954bc13f26f26a6ecf" - }, - "backports": { - "Package": "backports", - "Version": "1.4.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "c39fbec8a30d23e721980b8afb31984c" - }, - "base64enc": { - "Package": "base64enc", - "Version": "0.1-3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "543776ae6848fde2f48ff3816d0628bc" - }, - "bit": { - "Package": "bit", - "Version": "4.0.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "d242abec29412ce988848d0294b208fd" - }, - "bit64": { - "Package": "bit64", - "Version": "4.0.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "bit", - "methods", - "stats", - "utils" - ], - "Hash": "9fe98599ca456d6552421db0d6772d8f" - }, - "bslib": { - "Package": "bslib", - "Version": "0.6.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "base64enc", - "cachem", - "grDevices", - "htmltools", - "jquerylib", - "jsonlite", - "lifecycle", - "memoise", - "mime", - "rlang", - "sass" - ], - "Hash": "c0d8599494bc7fb408cd206bbdd9cab0" - }, - "cachem": { - "Package": "cachem", - "Version": "1.0.8", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "fastmap", - "rlang" - ], - "Hash": "c35768291560ce302c0a6589f92e837d" - }, - "cellranger": { - "Package": "cellranger", - "Version": "1.1.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "rematch", - "tibble" - ], - "Hash": "f61dbaec772ccd2e17705c1e872e9e7c" - }, - "checkmate": { - "Package": "checkmate", - "Version": "2.3.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "backports", - "utils" - ], - "Hash": "c01cab1cb0f9125211a6fc99d540e315" - }, - "cli": { - "Package": "cli", - "Version": "3.6.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "1216ac65ac55ec0058a6f75d7ca0fd52" - }, - "clipr": { - "Package": "clipr", - "Version": "0.8.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "utils" - ], - "Hash": "3f038e5ac7f41d4ac41ce658c85e3042" - }, - "cluster": { - "Package": "cluster", - "Version": "2.1.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "stats", - "utils" - ], - "Hash": "5edbbabab6ce0bf7900a74fd4358628e" - }, - "cmprsk": { - "Package": "cmprsk", - 
"Version": "2.2-11", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "survival" - ], - "Hash": "677e2fde792ef6737926f92b9cb804e6" - }, - "codetools": { - "Package": "codetools", - "Version": "0.2-18", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "019388fc48e48b3da0d3a76ff94608a8" - }, - "coefplot": { - "Package": "coefplot", - "Version": "1.2.8", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "dplyr", - "dygraphs", - "ggplot2", - "magrittr", - "plotly", - "plyr", - "purrr", - "reshape2", - "stats", - "tibble", - "useful" - ], - "Hash": "08c3560669c50f0a7f2b9e1524860ea2" - }, - "colorspace": { - "Package": "colorspace", - "Version": "2.1-0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "methods", - "stats" - ], - "Hash": "f20c47fd52fae58b4e377c37bb8c335b" - }, - "cowplot": { - "Package": "cowplot", - "Version": "1.1.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "ggplot2", - "grDevices", - "grid", - "gtable", - "methods", - "rlang", - "scales" - ], - "Hash": "ef28211987921217c61b4f4068068dac" - }, - "cpp11": { - "Package": "cpp11", - "Version": "0.4.7", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "5a295d7d963cc5035284dcdbaf334f4e" - }, - "crayon": { - "Package": "crayon", - "Version": "1.5.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "grDevices", - "methods", - "utils" - ], - "Hash": "e8a1e41acf02548751f45c718d55aa6a" - }, - "crosstalk": { - "Package": "crosstalk", - "Version": "1.2.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R6", - "htmltools", - "jsonlite", - "lazyeval" - ], - "Hash": "ab12c7b080a57475248a30f4db6298c0" - }, - "curl": { - "Package": "curl", - "Version": "5.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "ce88d13c0b10fe88a37d9c59dba2d7f9" - }, - "data.table": { - "Package": "data.table", - "Version": "1.14.10", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "6ea17a32294d8ca00455825ab0cf71b9" - }, - "diagram": { - "Package": "diagram", - "Version": "1.6.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "graphics", - "shape", - "stats" - ], - "Hash": "c7f527c59edc72c4bce63519b8d38752" - }, - "digest": { - "Package": "digest", - "Version": "0.6.34", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "7ede2ee9ea8d3edbf1ca84c1e333ad1a" - }, - "doParallel": { - "Package": "doParallel", - "Version": "1.0.17", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "foreach", - "iterators", - "parallel", - "utils" - ], - "Hash": "451e5edf411987991ab6a5410c45011f" - }, - "dplyr": { - "Package": "dplyr", - "Version": "1.1.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "cli", - "generics", - "glue", - "lifecycle", - "magrittr", - "methods", - "pillar", - "rlang", - "tibble", - "tidyselect", - "utils", - "vctrs" - ], - "Hash": "fedd9d00c2944ff00a0e2696ccf048ec" - }, - "dygraphs": { - "Package": "dygraphs", - "Version": "1.1.1.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "htmltools", - "htmlwidgets", - "magrittr", - "xts", - "zoo" - ], - "Hash": "716869fffc16e282c118f8894e082a7d" - }, - 
"ellipsis": { - "Package": "ellipsis", - "Version": "0.3.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "rlang" - ], - "Hash": "bb0eec2fe32e88d9e2836c2f73ea2077" - }, - "evaluate": { - "Package": "evaluate", - "Version": "0.23", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "daf4a1246be12c1fa8c7705a0935c1a0" - }, - "fansi": { - "Package": "fansi", - "Version": "1.0.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "utils" - ], - "Hash": "962174cf2aeb5b9eea581522286a911f" - }, - "farver": { - "Package": "farver", - "Version": "2.1.1", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "8106d78941f34855c440ddb946b8f7a5" - }, - "fastDummies": { - "Package": "fastDummies", - "Version": "1.7.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "data.table", - "stringr", - "tibble" - ], - "Hash": "e0f9c0c051e0e8d89996d7f0c400539f" - }, - "fastmap": { - "Package": "fastmap", - "Version": "1.1.1", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "f7736a18de97dea803bde0a2daaafb27" - }, - "fontawesome": { - "Package": "fontawesome", - "Version": "0.5.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "htmltools", - "rlang" - ], - "Hash": "c2efdd5f0bcd1ea861c2d4e2a883a67d" - }, - "forcats": { - "Package": "forcats", - "Version": "1.0.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "lifecycle", - "magrittr", - "rlang", - "tibble" - ], - "Hash": "1a0a9a3d5083d0d573c4214576f1e690" - }, - "foreach": { - "Package": "foreach", - "Version": "1.5.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "codetools", - "iterators", - "utils" - ], - "Hash": "618609b42c9406731ead03adf5379850" - }, - "foreign": { - "Package": "foreign", - "Version": "0.8-83", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods", - "stats", - "utils" - ], - "Hash": "4e43a8846712a6d5991b19b9bd5f9199" - }, - "fs": { - "Package": "fs", - "Version": "1.6.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "47b5f30c720c23999b913a1a635cf0bb" - }, - "future": { - "Package": "future", - "Version": "1.33.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "digest", - "globals", - "listenv", - "parallel", - "parallelly", - "utils" - ], - "Hash": "e57e292737f7a4efa9d8a91c5908222c" - }, - "future.apply": { - "Package": "future.apply", - "Version": "1.11.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "future", - "globals", - "parallel", - "utils" - ], - "Hash": "455e00c16ec193c8edcf1b2b522b3288" - }, - "generics": { - "Package": "generics", - "Version": "0.1.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "15e9634c0fcd294799e9b2e929ed1b86" - }, - "ggplot2": { - "Package": "ggplot2", - "Version": "3.4.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "MASS", - "R", - "cli", - "glue", - "grDevices", - "grid", - "gtable", - "isoband", - "lifecycle", - "mgcv", - "rlang", - "scales", - "stats", - "tibble", - "vctrs", - "withr" - ], - "Hash": "313d31eff2274ecf4c1d3581db7241f9" - }, - "ggpubfigs": { - "Package": "ggpubfigs", - "Version": "0.0.1", - "Source": "GitHub", - "RemoteType": "github", - "RemoteHost": 
"api.github.com", - "RemoteUsername": "JLSteenwyk", - "RemoteRepo": "ggpubfigs", - "RemoteRef": "master", - "RemoteSha": "eb78b11264c3b518621878f8df8db95b4e78f094", - "Requirements": [ - "ggplot2" - ], - "Hash": "fca860a6fd7fdb0d6e929f1f7726788c" - }, - "ggsignif": { - "Package": "ggsignif", - "Version": "0.6.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "ggplot2" - ], - "Hash": "a57f0f5dbcfd0d77ad4ff33032f5dc79" - }, - "glmnet": { - "Package": "glmnet", - "Version": "4.1-8", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "R", - "Rcpp", - "RcppEigen", - "foreach", - "methods", - "shape", - "survival", - "utils" - ], - "Hash": "eb6fc70e561aae41d5911a6726188f71" - }, - "glmnetUtils": { - "Package": "glmnetUtils", - "Version": "1.1.9", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "glmnet", - "grDevices", - "graphics", - "parallel", - "stats" - ], - "Hash": "66165824c789439eff327118c90e844d" - }, - "globals": { - "Package": "globals", - "Version": "0.16.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "codetools" - ], - "Hash": "baa9585ab4ce47a9f4618e671778cc6f" - }, - "glue": { - "Package": "glue", - "Version": "1.7.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "e0b3a53876554bd45879e596cdb10a52" - }, - "gridExtra": { - "Package": "gridExtra", - "Version": "2.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "grDevices", - "graphics", - "grid", - "gtable", - "utils" - ], - "Hash": "7d7f283939f563670a697165b2cf5560" - }, - "gtable": { - "Package": "gtable", - "Version": "0.3.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "grid", - "lifecycle", - "rlang" - ], - "Hash": "b29cf3031f49b04ab9c852c912547eef" - }, - "here": { - "Package": "here", - "Version": "1.0.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "rprojroot" - ], - "Hash": "24b224366f9c2e7534d2344d10d59211" - }, - "highr": { - "Package": "highr", - "Version": "0.10", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "xfun" - ], - "Hash": "06230136b2d2b9ba5805e1963fa6e890" - }, - "hms": { - "Package": "hms", - "Version": "1.1.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "lifecycle", - "methods", - "pkgconfig", - "rlang", - "vctrs" - ], - "Hash": "b59377caa7ed00fa41808342002138f9" - }, - "htmlTable": { - "Package": "htmlTable", - "Version": "2.4.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "checkmate", - "htmltools", - "htmlwidgets", - "knitr", - "magrittr", - "methods", - "rstudioapi", - "stringr" - ], - "Hash": "0164d8cade33fac2190703da7e6e3241" - }, - "htmltools": { - "Package": "htmltools", - "Version": "0.5.7", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "base64enc", - "digest", - "ellipsis", - "fastmap", - "grDevices", - "rlang", - "utils" - ], - "Hash": "2d7b3857980e0e0d0a1fd6f11928ab0f" - }, - "htmlwidgets": { - "Package": "htmlwidgets", - "Version": "1.6.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "grDevices", - "htmltools", - "jsonlite", - "knitr", - "rmarkdown", - "yaml" - ], - "Hash": "04291cc45198225444a397606810ac37" - }, - "httr": { - "Package": "httr", - "Version": "1.4.7", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - 
"R6", - "curl", - "jsonlite", - "mime", - "openssl" - ], - "Hash": "ac107251d9d9fd72f0ca8049988f1d7f" - }, - "irlba": { - "Package": "irlba", - "Version": "2.3.5.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "R", - "methods", - "stats" - ], - "Hash": "acb06a47b732c6251afd16e19c3201ff" - }, - "isoband": { - "Package": "isoband", - "Version": "0.2.7", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "grid", - "utils" - ], - "Hash": "0080607b4a1a7b28979aecef976d8bc2" - }, - "iterators": { - "Package": "iterators", - "Version": "1.0.14", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "8954069286b4b2b0d023d1b288dce978" - }, - "janitor": { - "Package": "janitor", - "Version": "2.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "dplyr", - "hms", - "lifecycle", - "lubridate", - "magrittr", - "purrr", - "rlang", - "snakecase", - "stringi", - "stringr", - "tidyr", - "tidyselect" - ], - "Hash": "5baae149f1082f466df9d1442ba7aa65" - }, - "jquerylib": { - "Package": "jquerylib", - "Version": "0.1.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "htmltools" - ], - "Hash": "5aab57a3bd297eee1c1d862735972182" - }, - "jsonlite": { - "Package": "jsonlite", - "Version": "1.8.8", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "methods" - ], - "Hash": "e1b9c55281c5adc4dd113652d9e26768" - }, - "knitr": { - "Package": "knitr", - "Version": "1.45", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "evaluate", - "highr", - "methods", - "tools", - "xfun", - "yaml" - ], - "Hash": "1ec462871063897135c1bcbe0fc8f07d" - }, - "labeling": { - "Package": "labeling", - "Version": "0.4.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "graphics", - "stats" - ], - "Hash": "b64ec208ac5bc1852b285f665d6368b3" - }, - "later": { - "Package": "later", - "Version": "1.3.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Rcpp", - "rlang" - ], - "Hash": "a3e051d405326b8b0012377434c62b37" - }, - "lattice": { - "Package": "lattice", - "Version": "0.20-45", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "grid", - "stats", - "utils" - ], - "Hash": "b64cdbb2b340437c4ee047a1f4c4377b" - }, - "lava": { - "Package": "lava", - "Version": "1.7.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "SQUAREM", - "future.apply", - "grDevices", - "graphics", - "methods", - "numDeriv", - "progressr", - "stats", - "survival", - "utils" - ], - "Hash": "975f46623ba2e2c059fc959e8bee92b8" - }, - "lazyeval": { - "Package": "lazyeval", - "Version": "0.2.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "d908914ae53b04d4c0c0fd72ecc35370" - }, - "lifecycle": { - "Package": "lifecycle", - "Version": "1.0.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "rlang" - ], - "Hash": "b8552d117e1b808b09a832f589b79035" - }, - "listenv": { - "Package": "listenv", - "Version": "0.9.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "4fbd3679ec8ee169ba28d4b1ea7d0e8f" - }, - "lubridate": { - "Package": "lubridate", - "Version": "1.9.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "generics", - "methods", - "timechange" - ], - 
"Hash": "680ad542fbcf801442c83a6ac5a2126c" - }, - "magrittr": { - "Package": "magrittr", - "Version": "2.0.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "7ce2733a9826b3aeb1775d56fd305472" - }, - "memoise": { - "Package": "memoise", - "Version": "2.0.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "cachem", - "rlang" - ], - "Hash": "e2817ccf4a065c5d9d7f2cfbe7c1d78c" - }, - "mets": { - "Package": "mets", - "Version": "1.3.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "Rcpp", - "RcppArmadillo", - "compiler", - "lava", - "mvtnorm", - "numDeriv", - "splines", - "survival", - "timereg" - ], - "Hash": "a40f717758d8dfb7fcb3b5fa457b17b1" - }, - "mgcv": { - "Package": "mgcv", - "Version": "1.8-41", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "R", - "graphics", - "methods", - "nlme", - "splines", - "stats", - "utils" - ], - "Hash": "6b3904f13346742caa3e82dd0303d4ad" - }, - "microbenchmark": { - "Package": "microbenchmark", - "Version": "1.4.10", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "graphics", - "stats" - ], - "Hash": "db81b552e393ed092872cf7023469bc2" - }, - "mime": { - "Package": "mime", - "Version": "0.12", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "tools" - ], - "Hash": "18e9c28c1d3ca1560ce30658b22ce104" - }, - "multcomp": { - "Package": "multcomp", - "Version": "1.4-25", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "TH.data", - "codetools", - "graphics", - "mvtnorm", - "sandwich", - "stats", - "survival" - ], - "Hash": "2688bf2f8d54c19534ee7d8a876d9fc7" - }, - "munsell": { - "Package": "munsell", - "Version": "0.5.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "colorspace", - "methods" - ], - "Hash": "6dfe8bf774944bd5595785e3229d8771" - }, - "mvtnorm": { - "Package": "mvtnorm", - "Version": "1.2-4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "stats" - ], - "Hash": "17e96668f44a28aef0981d9e17c49b59" - }, - "nlme": { - "Package": "nlme", - "Version": "3.1-160", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "graphics", - "lattice", - "stats", - "utils" - ], - "Hash": "02e3c6e7df163aafa8477225e6827bc5" - }, - "nnet": { - "Package": "nnet", - "Version": "7.3-18", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "stats", - "utils" - ], - "Hash": "170da2130d5332bea7d6ede01875ba1d" - }, - "numDeriv": { - "Package": "numDeriv", - "Version": "2016.8-1.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "df58958f293b166e4ab885ebcad90e02" - }, - "openssl": { - "Package": "openssl", - "Version": "2.1.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "askpass" - ], - "Hash": "2a0dc8c6adfb6f032e4d4af82d258ab5" - }, - "parallelly": { - "Package": "parallelly", - "Version": "1.36.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "parallel", - "tools", - "utils" - ], - "Hash": "bca377e1c87ec89ebed77bba00635b2e" - }, - "pec": { - "Package": "pec", - "Version": "2023.04.12", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "foreach", - "lava", - "prodlim", - "riskRegression", - "rms", - "survival", - "timereg" - ], - "Hash": "4c8faf404f9d6926b6108c123c2b79b3" - }, - "penAFT": { - "Package": "penAFT", - "Version": "0.3.0", - 
"Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "Rcpp", - "RcppArmadillo", - "ggplot2", - "irlba" - ], - "Hash": "2a896128b68516fd236e584dd44b0f1e" - }, - "pillar": { - "Package": "pillar", - "Version": "1.9.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "cli", - "fansi", - "glue", - "lifecycle", - "rlang", - "utf8", - "utils", - "vctrs" - ], - "Hash": "15da5a8412f317beeee6175fbc76f4bb" - }, - "pkgconfig": { - "Package": "pkgconfig", - "Version": "2.0.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "utils" - ], - "Hash": "01f28d4278f15c76cddbea05899c5d6f" - }, - "plotly": { - "Package": "plotly", - "Version": "4.10.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "RColorBrewer", - "base64enc", - "crosstalk", - "data.table", - "digest", - "dplyr", - "ggplot2", - "htmltools", - "htmlwidgets", - "httr", - "jsonlite", - "lazyeval", - "magrittr", - "promises", - "purrr", - "rlang", - "scales", - "tibble", - "tidyr", - "tools", - "vctrs", - "viridisLite" - ], - "Hash": "a1ac5c03ad5ad12b9d1597e00e23c3dd" - }, - "plotrix": { - "Package": "plotrix", - "Version": "3.8-4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "stats", - "utils" - ], - "Hash": "d47fdfc45aeba360ce9db50643de3fbd" - }, - "plyr": { - "Package": "plyr", - "Version": "1.8.9", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "Rcpp" - ], - "Hash": "6b8177fd19982f0020743fadbfdbd933" - }, - "polspline": { - "Package": "polspline", - "Version": "1.1.24", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "graphics", - "stats" - ], - "Hash": "25658353186c2763a3b0fb92c9fd8ff8" - }, - "prettyunits": { - "Package": "prettyunits", - "Version": "1.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "6b01fc98b1e86c4f705ce9dcfd2f57c7" - }, - "prodlim": { - "Package": "prodlim", - "Version": "2023.08.28", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "KernSmooth", - "R", - "Rcpp", - "data.table", - "diagram", - "grDevices", - "graphics", - "lava", - "stats", - "survival" - ], - "Hash": "c73e09a2039a0f75ac0a1e5454b39993" - }, - "progress": { - "Package": "progress", - "Version": "1.2.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "crayon", - "hms", - "prettyunits" - ], - "Hash": "f4625e061cb2865f111b47ff163a5ca6" - }, - "progressr": { - "Package": "progressr", - "Version": "0.14.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "digest", - "utils" - ], - "Hash": "ac50c4ffa8f6a46580dd4d7813add3c4" - }, - "promises": { - "Package": "promises", - "Version": "1.2.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R6", - "Rcpp", - "fastmap", - "later", - "magrittr", - "rlang", - "stats" - ], - "Hash": "0d8a15c9d000970ada1ab21405387dee" - }, - "purrr": { - "Package": "purrr", - "Version": "1.0.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "lifecycle", - "magrittr", - "rlang", - "vctrs" - ], - "Hash": "1cba04a4e9414bdefc9dcaa99649a8dc" - }, - "quantreg": { - "Package": "quantreg", - "Version": "5.97", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "MASS", - "Matrix", - "MatrixModels", - "R", - "SparseM", - "graphics", - "methods", - "stats", - "survival" - ], - "Hash": 
"1bbc97f7d637ab3917c514a69047b2c1" - }, - "ranger": { - "Package": "ranger", - "Version": "0.16.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "R", - "Rcpp", - "RcppEigen" - ], - "Hash": "d5ca3a8d00f088042ea3b638534e0f3d" - }, - "rappdirs": { - "Package": "rappdirs", - "Version": "0.3.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "5e3c5dc0b071b21fa128676560dbe94d" - }, - "readr": { - "Package": "readr", - "Version": "2.1.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "cli", - "clipr", - "cpp11", - "crayon", - "hms", - "lifecycle", - "methods", - "rlang", - "tibble", - "tzdb", - "utils", - "vroom" - ], - "Hash": "9de96463d2117f6ac49980577939dfb3" - }, - "readxl": { - "Package": "readxl", - "Version": "1.4.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cellranger", - "cpp11", - "progress", - "tibble", - "utils" - ], - "Hash": "8cf9c239b96df1bbb133b74aef77ad0a" - }, - "rematch": { - "Package": "rematch", - "Version": "2.0.0", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "cbff1b666c6fa6d21202f07e2318d4f1" - }, - "renv": { - "Package": "renv", - "Version": "0.17.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "utils" - ], - "Hash": "4543b8cd233ae25c6aba8548be9e747e" - }, - "reshape2": { - "Package": "reshape2", - "Version": "1.4.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "Rcpp", - "plyr", - "stringr" - ], - "Hash": "bb5996d0bd962d214a11140d77589917" - }, - "riskRegression": { - "Package": "riskRegression", - "Version": "2023.12.21", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Publish", - "R", - "Rcpp", - "RcppArmadillo", - "cmprsk", - "data.table", - "doParallel", - "foreach", - "ggplot2", - "graphics", - "lattice", - "lava", - "mets", - "mvtnorm", - "parallel", - "plotrix", - "prodlim", - "ranger", - "rms", - "stats", - "survival", - "timereg" - ], - "Hash": "98285aebfa1754a73115621ba7b2bb51" - }, - "rjson": { - "Package": "rjson", - "Version": "0.2.21", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "f9da75e6444e95a1baf8ca24909d63b9" - }, - "rlang": { - "Package": "rlang", - "Version": "1.1.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "42548638fae05fd9a9b5f3f437fbbbe2" - }, - "rmarkdown": { - "Package": "rmarkdown", - "Version": "2.25", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "bslib", - "evaluate", - "fontawesome", - "htmltools", - "jquerylib", - "jsonlite", - "knitr", - "methods", - "stringr", - "tinytex", - "tools", - "utils", - "xfun", - "yaml" - ], - "Hash": "d65e35823c817f09f4de424fcdfa812a" - }, - "rms": { - "Package": "rms", - "Version": "6.7-1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Hmisc", - "MASS", - "R", - "SparseM", - "cluster", - "colorspace", - "digest", - "ggplot2", - "grDevices", - "htmlTable", - "htmltools", - "knitr", - "methods", - "multcomp", - "nlme", - "polspline", - "quantreg", - "rpart", - "survival" - ], - "Hash": "334d50cf2e013943a4744da3b0332402" - }, - "rpart": { - "Package": "rpart", - "Version": "4.1.19", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "stats" - ], - "Hash": "b3c892a81783376cc2204af0f5805a80" - }, - "rprojroot": { - "Package": 
"rprojroot", - "Version": "2.0.4", - "Source": "Repository", - "Repository": "RSPM", - "Requirements": [ - "R" - ], - "Hash": "4c8415e0ec1e29f3f4f6fc108bef0144" - }, - "rstudioapi": { - "Package": "rstudioapi", - "Version": "0.15.0", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "5564500e25cffad9e22244ced1379887" - }, - "sandwich": { - "Package": "sandwich", - "Version": "3.1-0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "stats", - "utils", - "zoo" - ], - "Hash": "1cf6ae532f0179350862fefeb0987c9b" - }, - "sass": { - "Package": "sass", - "Version": "0.4.8", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R6", - "fs", - "htmltools", - "rappdirs", - "rlang" - ], - "Hash": "168f9353c76d4c4b0a0bbf72e2c2d035" - }, - "scales": { - "Package": "scales", - "Version": "1.3.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "RColorBrewer", - "cli", - "farver", - "glue", - "labeling", - "lifecycle", - "munsell", - "rlang", - "viridisLite" - ], - "Hash": "c19df082ba346b0ffa6f833e92de34d1" - }, - "shape": { - "Package": "shape", - "Version": "1.4.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "stats" - ], - "Hash": "9067f962730f58b14d8ae54ca885509f" - }, - "snakecase": { - "Package": "snakecase", - "Version": "0.11.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "stringi", - "stringr" - ], - "Hash": "58767e44739b76965332e8a4fe3f91f1" - }, - "splitTools": { - "Package": "splitTools", - "Version": "1.0.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "stats" - ], - "Hash": "e17dc90796ae3e8a93fc4ac85c7f857b" - }, - "stringi": { - "Package": "stringi", - "Version": "1.8.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "stats", - "tools", - "utils" - ], - "Hash": "058aebddea264f4c99401515182e656a" - }, - "stringr": { - "Package": "stringr", - "Version": "1.5.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "lifecycle", - "magrittr", - "rlang", - "stringi", - "vctrs" - ], - "Hash": "960e2ae9e09656611e0b8214ad543207" - }, - "survival": { - "Package": "survival", - "Version": "3.4-0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "R", - "graphics", - "methods", - "splines", - "stats", - "utils" - ], - "Hash": "04411ae66ab4659230c067c32966fc20" - }, - "sys": { - "Package": "sys", - "Version": "3.4.2", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "3a1be13d68d47a8cd0bfd74739ca1555" - }, - "tibble": { - "Package": "tibble", - "Version": "3.2.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "fansi", - "lifecycle", - "magrittr", - "methods", - "pillar", - "pkgconfig", - "rlang", - "utils", - "vctrs" - ], - "Hash": "a84e2cc86d07289b3b6f5069df7a004c" - }, - "tidyr": { - "Package": "tidyr", - "Version": "1.3.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "cpp11", - "dplyr", - "glue", - "lifecycle", - "magrittr", - "purrr", - "rlang", - "stringr", - "tibble", - "tidyselect", - "utils", - "vctrs" - ], - "Hash": "e47debdc7ce599b070c8e78e8ac0cfcf" - }, - "tidyselect": { - "Package": "tidyselect", - "Version": "1.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "lifecycle", - "rlang", - "vctrs", - "withr" - ], - 
"Hash": "79540e5fcd9e0435af547d885f184fd5" - }, - "timechange": { - "Package": "timechange", - "Version": "0.3.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cpp11" - ], - "Hash": "c5f3c201b931cd6474d17d8700ccb1c8" - }, - "timereg": { - "Package": "timereg", - "Version": "2.0.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "lava", - "methods", - "numDeriv", - "stats", - "survival", - "utils" - ], - "Hash": "554d68bf30e775628a81b992e1a4876d" - }, - "tinytex": { - "Package": "tinytex", - "Version": "0.49", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "xfun" - ], - "Hash": "5ac22900ae0f386e54f1c307eca7d843" - }, - "tzdb": { - "Package": "tzdb", - "Version": "0.4.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cpp11" - ], - "Hash": "f561504ec2897f4d46f0c7657e488ae1" - }, - "useful": { - "Package": "useful", - "Version": "1.2.6.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "assertthat", - "dplyr", - "ggplot2", - "magrittr", - "plyr", - "purrr", - "scales", - "stats", - "utils" - ], - "Hash": "2e76a6230fd4a81383ee91569c04e757" - }, - "utf8": { - "Package": "utf8", - "Version": "1.2.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "62b65c52671e6665f803ff02954446e9" - }, - "vctrs": { - "Package": "vctrs", - "Version": "0.6.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "lifecycle", - "rlang" - ], - "Hash": "c03fa420630029418f7e6da3667aac4a" - }, - "viridis": { - "Package": "viridis", - "Version": "0.6.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "ggplot2", - "gridExtra", - "viridisLite" - ], - "Hash": "80cd127bc8c9d3d9f0904ead9a9102f1" - }, - "viridisLite": { - "Package": "viridisLite", - "Version": "0.4.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "c826c7c4241b6fc89ff55aaea3fa7491" - }, - "vroom": { - "Package": "vroom", - "Version": "1.6.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "bit64", - "cli", - "cpp11", - "crayon", - "glue", - "hms", - "lifecycle", - "methods", - "progress", - "rlang", - "stats", - "tibble", - "tidyselect", - "tzdb", - "vctrs", - "withr" - ], - "Hash": "390f9315bc0025be03012054103d227c" - }, - "withr": { - "Package": "withr", - "Version": "3.0.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics" - ], - "Hash": "d31b6c62c10dcf11ec530ca6b0dd5d35" - }, - "xfun": { - "Package": "xfun", - "Version": "0.41", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "stats", - "tools" - ], - "Hash": "460a5e0fe46a80ef87424ad216028014" - }, - "xts": { - "Package": "xts", - "Version": "0.13.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods", - "zoo" - ], - "Hash": "b8aa1235fd8b0ff10756150b792dc60f" - }, - "yaml": { - "Package": "yaml", - "Version": "2.3.8", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "29240487a071f535f5e5d5a323b7afbd" - }, - "zoo": { - "Package": "zoo", - "Version": "1.8-12", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "lattice", - "stats", - "utils" - ], - "Hash": "5c715954112b45499fb1dadc6ee6ee3e" - } - } -} diff --git a/paper/reproduce.sh 
b/paper/reproduce.sh deleted file mode 100644 index 3666e90..0000000 --- a/paper/reproduce.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -export OMP_NUM_THREADS=1 - -bash scripts/sh/create_paths.sh &> create_paths.log -bash scripts/sh/download_data.sh &> download_data.log -bash scripts/sh/preprocess_data.sh &> preprocess_data.log -bash scripts/sh/make_splits.sh &> make_splits.log -bash scripts/sh/rerun_experiments.sh &> reproduce_experiments.log -bash scripts/sh/remake_figures_and_tables.sh &> reproduce_figures_and_tables.log diff --git a/paper/requirements.txt b/paper/requirements.txt deleted file mode 100644 index 1c310fc..0000000 --- a/paper/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -celer==0.7.2 -scikit-learn==1.2.2 -scikit-survival==0.21.0 -pandas==1.5.3 -numba==0.57.1 -numpy==1.23.4 -pycox==0.2.3 -sparsesurv @ git+https://github.com/BoevaLab/sparsesurv@paper -skorch==0.15.0 -torch==2.1.2 diff --git a/paper/scripts/py/make_table_S1.py b/paper/scripts/py/make_table_S1.py deleted file mode 100644 index 43f6df9..0000000 --- a/paper/scripts/py/make_table_S1.py +++ /dev/null @@ -1,73 +0,0 @@ -import json - -import numpy as np -import pandas as pd -from sparsesurv.utils import transform_survival - -with open("./config.json") as f: - config = json.load(f) - -np.random.seed(config["random_state"]) - -cancer_type = config["datasets"] -tissue = [ - "Bladder", - "Breast", - "Head and neck", - "Kidney", - "Brain", - "Liver", - "Lung", - "Lung", - "Ovaries", - "Stomach", -] -full_name = [ - "Bladder Urothelial Carcinoma", - "Breast invasive carcinoma", - "Head and neck squamous cell carcinoma", - "Kidney renal clear cell carcinoma", - "Brain lower grade glioma", - "Liver hepatocellular carcinoma", - "Lung adenocarcinoma", - "Lung squamous cell carcinoma", - "Ovarian serous cystadenocarcinoma", - "Stomach adenocarcinoma", -] -p = [] -n = [] -event_ratio = [] -min_event_time = [] -max_event_time = [] -median_event_time = [] - - -for cancer in config["datasets"]: - print(f"Starting: {cancer}") - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - data = pd.read_csv(f"./data/processed/TCGA/{cancer}_data_preprocessed.csv").iloc[ - :, 1: - ] - X_ = data.iloc[:, 3:] - y_ = transform_survival(time=data["OS_days"].values, event=data["OS"].values) - p.append(X_.shape[1]) - n.append(X_.shape[0]) - event_ratio.append(np.mean(data["OS"].values)) - min_event_time.append(np.min(data["OS_days"].values)) - max_event_time.append(np.max(data["OS_days"].values)) - median_event_time.append(np.median(data["OS_days"].values)) - -pd.DataFrame( - { - "type": cancer_type, - "tissue": tissue, - "full_name": full_name, - "p": p, - "n": n, - "event_ratio": event_ratio, - "min_event_time": min_event_time, - "max_event_time": max_event_time, - "median_event_time": median_event_time, - } -).to_csv("./tables/table_S1.csv", index=False) diff --git a/paper/scripts/py/rerun_splits.py b/paper/scripts/py/rerun_splits.py deleted file mode 100644 index 5bcf01a..0000000 --- a/paper/scripts/py/rerun_splits.py +++ /dev/null @@ -1,56 +0,0 @@ -import argparse -import json -import os - -import numpy as np -import pandas as pd -from sklearn.model_selection import RepeatedStratifiedKFold - -parser = argparse.ArgumentParser() -parser.add_argument("--data_dir", type=str) -parser.add_argument( - "--config_path", - type=str, -) - - -def main(data_dir, config_path) -> int: - with open(f"{config_path}/config.json") as f: - config = json.load(f)
- np.random.seed(config["random_state"]) - splits_path = os.path.join(data_dir, "splits", "TCGA") - os.makedirs(splits_path, exist_ok=True) - for cancer in config["datasets"]: - print(cancer) - data_path = f"processed/TCGA/{cancer}_data_preprocessed.csv" - data = pd.read_csv( - os.path.join(data_dir, data_path), - low_memory=False, - ) - - # Exact column choice doesn't matter - # as this is only to create the splits anyway. - X = data[[i for i in data.columns if i not in ["OS_days", "OS"]]] - cv = RepeatedStratifiedKFold( - n_repeats=config["n_outer_repetitions"], - n_splits=config["n_outer_splits"], - random_state=config["random_state"], - ) - splits = [i for i in cv.split(X, data["OS"])] - pd.DataFrame([i[0] for i in splits]).to_csv( - f"{splits_path}/{cancer}_train_splits.csv", - index=False, - ) - pd.DataFrame([i[1] for i in splits]).to_csv( - f"{splits_path}/{cancer}_test_splits.csv", - index=False, - ) - return 0 - - -if __name__ == "__main__": - args = parser.parse_args() - main( - args.data_dir, - args.config_path, - ) diff --git a/paper/scripts/py/run_path_sparsesurv.py b/paper/scripts/py/run_path_sparsesurv.py deleted file mode 100644 index da662f4..0000000 --- a/paper/scripts/py/run_path_sparsesurv.py +++ /dev/null @@ -1,250 +0,0 @@ -import json - -import celer -import numpy as np -import pandas as pd -import torch -from sklearn.decomposition import PCA -from sklearn.feature_selection import VarianceThreshold -from sklearn.metrics import make_scorer -from sklearn.model_selection import GridSearchCV, RandomizedSearchCV -from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import StandardScaler -from skorch.callbacks import EarlyStopping -from sksurv.linear_model import CoxPHSurvivalAnalysis - -from sparsesurv.cv import KDPHElasticNetCV -from sparsesurv.loss import breslow_negative_likelihood, efron_negative_likelihood -from sparsesurv.neuralsurv.python.model.model import SKORCH_MODULE_FACTORY -from sparsesurv.neuralsurv.python.model.skorch_infra import FixSeed -from sparsesurv.neuralsurv.python.utils.factories import ( - CRITERION_FACTORY, - SKORCH_NET_FACTORY, -) -from sparsesurv.neuralsurv.python.utils.misc_utils import ( - StratifiedSkorchSurvivalSplit, - StratifiedSurvivalKFold, -) -from sparsesurv.utils import inverse_transform_survival, transform_survival - -with open("./config.json") as f: - config = json.load(f) - - -def breslow_score_wrapper(y_true, y_pred): - time, event = inverse_transform_survival(y_true) - return np.negative( - breslow_negative_likelihood( - linear_predictor=np.squeeze(y_pred), time=time, event=event - ) - ) - - -def efron_score_wrapper(y_true, y_pred): - time, event = inverse_transform_survival(y_true) - return np.negative( - efron_negative_likelihood(linear_predictor=y_pred, time=time, event=event) - ) - - -SCORE_FACTORY = {"breslow": breslow_score_wrapper, "efron": efron_score_wrapper} - -np.random.seed(config["random_state"]) -g = np.random.default_rng(config.get("random_state")) -model_pipe = make_pipeline( - VarianceThreshold(), - StandardScaler(), -) -en = celer.ElasticNet( - l1_ratio=config["l1_ratio"], - fit_intercept=False, -) - - -for tie_correction in ["breslow", "efron"]: - pc_pipe = GridSearchCV( - estimator=make_pipeline( - VarianceThreshold(), - StandardScaler(), - PCA(n_components=config[f"pc_n_components"]), - CoxPHSurvivalAnalysis(ties=tie_correction), - ), - param_grid={"pca__n_components": config["pc_n_components_tuned"]}, - n_jobs=config["n_jobs"], - scoring=make_scorer(SCORE_FACTORY[tie_correction]), - 
cv=StratifiedSurvivalKFold( - n_splits=config["n_inner_cv"], - shuffle=config["shuffle_cv"], - random_state=config["random_state"], - ), - ) - - for cancer in config["datasets"]: - sparsity = {} - - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - data = pd.read_csv( - f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - ).iloc[:, 1:] - X_ = data.iloc[:, 3:] - y_ = transform_survival(time=data["OS_days"].values, event=data["OS"].values) - for split in range(config["n_outer_splits"] * config["n_outer_repetitions"]): - train_ix = train_splits.iloc[split, :].dropna().to_numpy().astype(int) - test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) - X_train = X_.iloc[train_ix, :].copy().reset_index(drop=True).to_numpy() - y_train = y_[train_ix].copy() - y_test = y_[test_ix].copy() - X_test = X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy() - - pc_pipe.fit(X_train, y_train) - path_results = en.path( - X=model_pipe.fit_transform(X_train), - y=pc_pipe.predict(X_train), - l1_ratio=config["l1_ratio"], - eps=config["eps"], - n_alphas=config["n_alphas"], - alphas=None, - ) - - for z in range(config["n_alphas"]): - path_coef = path_results[1][:, z] - if z == 0: - sparsity[split] = [] - sparsity[split].append(np.sum(path_coef != 0.0)) - helper = KDPHElasticNetCV( - tie_correction="efron", - seed=np.random.RandomState(config["random_state"]), - ) - helper.coef_ = path_coef - ix_sort = np.argsort(y_train["time"]) - helper.train_time_ = y_train["time"][ix_sort] - helper.train_event_ = y_train["event"][ix_sort] - helper.train_eta_ = helper.predict(model_pipe.transform(X_train))[ - ix_sort - ] - surv = helper.predict_survival_function( - model_pipe.transform(X_test), np.unique(y_test["time"]) - ) - surv.to_csv( - f"./results/kd/{tie_correction}/{cancer}/path/survival_function_{z+1}_alpha_{split+1}.csv", - index=False, - ) - - pd.DataFrame(sparsity).to_csv( - f"./results/kd/{tie_correction}/{cancer}/path/sparsity.csv", - index=False, - ) - -for cancer in config["datasets"]: - sparsity = {} - - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - data = pd.read_csv(f"./data/processed/TCGA/{cancer}_data_preprocessed.csv").iloc[ - :, 1: - ] - X_ = data.iloc[:, 3:] - y_ = transform_survival(time=data["OS_days"].values, event=data["OS"].values) - pc_pipe = RandomizedSearchCV( - estimator=make_pipeline( - StandardScaler(), - SKORCH_NET_FACTORY["cox"]( - module=SKORCH_MODULE_FACTORY["cox"], - criterion=CRITERION_FACTORY["cox"], - module__fusion_method="early", - module__blocks=[[i for i in range(X_.shape[1])]], - iterator_train__shuffle=True, - optimizer=torch.optim.AdamW, - max_epochs=config["max_epochs"], - verbose=False, - train_split=StratifiedSkorchSurvivalSplit( - config["validation_set_neural"], - stratified=config["stratify_cv"], - random_state=config.get("random_state"), - ), - callbacks=[ - ( - "es", - EarlyStopping( - monitor="valid_loss", - patience=config["early_stopping_patience"], - load_best=True, - ), - ), - ("seed", FixSeed(generator=g)), - ], - module__activation=torch.nn.ReLU, - ), - ), - param_distributions={ - "coxphneuralnet__lr": config["tune_lr"], - "coxphneuralnet__optimizer__weight_decay": config["tune_weight_decay"], - "coxphneuralnet__module__modality_hidden_layer_size": config[ - "tune_modality_hidden_layer_size" - ], - 
"coxphneuralnet__module__modality_hidden_layers": config[ - "tune_modality_hidden_layers" - ], - "coxphneuralnet__module__p_dropout": config["tune_p_dropout"], - "coxphneuralnet__batch_size": config["tune_batch_size"], - }, - n_jobs=config["n_jobs"], - random_state=config["random_state"], - scoring=make_scorer(breslow_score_wrapper), - cv=StratifiedSurvivalKFold( - n_splits=config["n_inner_cv"], - shuffle=config["shuffle_cv"], - random_state=config["random_state"], - ), - error_score=config["error_score"], - verbose=False, - n_iter=config["random_search_n_iter"], - ) - for split in range(config["n_outer_splits"] * config["n_outer_repetitions"]): - train_ix = train_splits.iloc[split, :].dropna().to_numpy().astype(int) - test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) - X_train = ( - X_.iloc[train_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) - ) - y_train = y_[train_ix].copy() - y_test = y_[test_ix].copy() - X_test = X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) - - pc_pipe.fit(X_train, y_train) - path_results = en.path( - X=model_pipe.fit_transform(X_train), - y=np.squeeze(pc_pipe.predict(X_train)), - l1_ratio=config["l1_ratio"], - eps=config["eps"], - n_alphas=config["n_alphas"], - alphas=None, - ) - - for z in range(config["n_alphas"]): - path_coef = path_results[1][:, z] - if z == 0: - sparsity[split] = [] - sparsity[split].append(np.sum(path_coef != 0.0)) - helper = KDPHElasticNetCV( - tie_correction="breslow", - seed=np.random.RandomState(config["random_state"]), - ) - helper.coef_ = path_coef - ix_sort = np.argsort(y_train["time"]) - helper.train_time_ = y_train["time"][ix_sort] - helper.train_event_ = y_train["event"][ix_sort] - helper.train_eta_ = helper.predict(model_pipe.transform(X_train))[ix_sort] - - surv = helper.predict_survival_function( - model_pipe.transform(X_test), np.unique(y_test["time"]) - ) - surv.to_csv( - f"./results/kd/cox_nnet/{cancer}/path/survival_function_{z+1}_alpha_{split+1}.csv", - index=False, - ) - - pd.DataFrame(sparsity).to_csv( - f"./results/kd/cox_nnet/{cancer}/path/sparsity.csv", - index=False, - ) diff --git a/paper/scripts/py/run_sparsesurv.py b/paper/scripts/py/run_sparsesurv.py deleted file mode 100644 index e6a88b9..0000000 --- a/paper/scripts/py/run_sparsesurv.py +++ /dev/null @@ -1,329 +0,0 @@ -import json - -import numpy as np -import pandas as pd -import torch -from sklearn.decomposition import PCA -from sklearn.feature_selection import VarianceThreshold -from sklearn.metrics import make_scorer -from sklearn.model_selection import GridSearchCV, RandomizedSearchCV -from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import StandardScaler -from skorch.callbacks import EarlyStopping -from sksurv.linear_model import CoxPHSurvivalAnalysis - -from sparsesurv._base import KDSurv -from sparsesurv.cv import KDPHElasticNetCV -from sparsesurv.loss import breslow_negative_likelihood, efron_negative_likelihood -from sparsesurv.neuralsurv.python.model.model import SKORCH_MODULE_FACTORY -from sparsesurv.neuralsurv.python.model.skorch_infra import FixSeed -from sparsesurv.neuralsurv.python.utils.factories import ( - CRITERION_FACTORY, - SKORCH_NET_FACTORY, -) -from sparsesurv.neuralsurv.python.utils.misc_utils import ( - StratifiedSkorchSurvivalSplit, - StratifiedSurvivalKFold, -) -from sparsesurv.utils import inverse_transform_survival, transform_survival - -with open("./config.json") as f: - config = json.load(f) - - -def breslow_score_wrapper(y_true, y_pred): - time, event = 
inverse_transform_survival(y_true) - return np.negative( - breslow_negative_likelihood( - linear_predictor=np.squeeze(y_pred), time=time, event=event - ) - ) - - -def efron_score_wrapper(y_true, y_pred): - time, event = inverse_transform_survival(y_true) - return np.negative( - efron_negative_likelihood( - linear_predictor=np.squeeze(y_pred), time=time, event=event - ) - ) - - -np.random.seed(config["random_state"]) -g = np.random.default_rng(config.get("random_state")) - -for tune_l1_ratio in [False]: - for tie_correction in ["breslow"]: - for score_type in ["min", "pcvl"]: - for score in ["linear_predictor"]: - results = {} - failures = {} - sparsity = {} - pipe = KDSurv( - teacher=GridSearchCV( - estimator=make_pipeline( - VarianceThreshold(), - StandardScaler(), - PCA(n_components=config["pc_n_components_tuned"]), - CoxPHSurvivalAnalysis(ties=tie_correction), - ), - param_grid={ - "pca__n_components": config["pc_n_components_tuned"] - }, - n_jobs=config["n_jobs"], - scoring=make_scorer(efron_score_wrapper), - cv=StratifiedSurvivalKFold( - n_splits=config["n_inner_cv"], - shuffle=config["shuffle_cv"], - random_state=config["random_state"], - ), - ), - student=make_pipeline( - VarianceThreshold(), - StandardScaler(), - KDPHElasticNetCV( - tie_correction=tie_correction, - l1_ratio=config[ - f"l1_ratio{'_tuned' if tune_l1_ratio else ''}" - ], - eps=config["eps"], - n_alphas=config["n_alphas"], - cv=config["n_inner_cv"], - stratify_cv=config["stratify_cv"], - seed=np.random.RandomState(config["random_state"]), - shuffle_cv=config["shuffle_cv"], - n_jobs=config["n_jobs"], - cv_score_method=score, - alpha_type=score_type, - ), - ), - ) - - for cancer in config["datasets"]: - - train_splits = pd.read_csv( - f"./data/splits/TCGA/{cancer}_train_splits.csv" - ) - test_splits = pd.read_csv( - f"./data/splits/TCGA/{cancer}_test_splits.csv" - ) - data = pd.read_csv( - f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - ).iloc[:, 1:] - X_ = data.iloc[:, 3:] - y_ = transform_survival( - time=data["OS_days"].values, event=data["OS"].values - ) - for split in range(25): - train_ix = ( - train_splits.iloc[split, :].dropna().to_numpy().astype(int) - ) - test_ix = ( - test_splits.iloc[split, :].dropna().to_numpy().astype(int) - ) - X_train = ( - X_.iloc[train_ix, :] - .copy() - .reset_index(drop=True) - .to_numpy() - ) - y_train = y_[train_ix].copy() - y_test = y_[test_ix].copy() - X_test = ( - X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy() - ) - if split == 0: - results[cancer] = {} - sparsity[cancer] = {} - failures[cancer] = [0] - try: - pipe.fit(X_train, y_train) - sparsity[cancer][split] = np.sum( - pipe.student[-1].coef_ != 0 - ) - results[cancer][split] = pipe.predict(X_test) - surv = pipe.predict_survival_function( - X_test, np.unique(y_test["time"]) - ) - surv.to_csv( - f"./results/kd/{tie_correction}/{cancer}/survival_function{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}_{split+1}.csv", - index=False, - ) - except ValueError as e: - failures[cancer][0] += 1 - results[cancer][split] = np.zeros(test_ix.shape[0]) - sparsity[cancer][split] = 0 - - pd.concat( - [pd.DataFrame(results[cancer][i]) for i in range(25)], - axis=1, - ).to_csv( - f"./results/kd/{tie_correction}/{cancer}/eta{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", - index=False, - ) - - pd.DataFrame(sparsity).to_csv( - f"./results/kd/{tie_correction}/sparsity{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", - index=False, - ) - pd.DataFrame(failures).to_csv( - 
f"./results/kd/{tie_correction}/failures{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", - index=False, - ) - -for tune_l1_ratio in []: - for score_type in ["min"]: - results = {} - failures = {} - sparsity = {} - for cancer in config["datasets"]: - for score in ["linear_predictor"]: - - train_splits = pd.read_csv( - f"./data/splits/TCGA/{cancer}_train_splits.csv" - ) - test_splits = pd.read_csv( - f"./data/splits/TCGA/{cancer}_test_splits.csv" - ) - data = pd.read_csv( - f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - ).iloc[:, 1:] - X_ = data.iloc[:, 3:] - y_ = transform_survival( - time=data["OS_days"].values, event=data["OS"].values - ) - pipe = KDSurv( - teacher=RandomizedSearchCV( - estimator=make_pipeline( - StandardScaler(), - SKORCH_NET_FACTORY["cox"]( - module=SKORCH_MODULE_FACTORY["cox"], - criterion=CRITERION_FACTORY["cox"], - module__fusion_method="early", - module__blocks=[[i for i in range(X_.shape[1])]], - iterator_train__shuffle=True, - optimizer=torch.optim.AdamW, - max_epochs=config["max_epochs"], - verbose=False, - train_split=StratifiedSkorchSurvivalSplit( - config["validation_set_neural"], - stratified=config["stratify_cv"], - random_state=config.get("random_state"), - ), - callbacks=[ - ( - "es", - EarlyStopping( - monitor="valid_loss", - patience=config["early_stopping_patience"], - load_best=True, - ), - ), - ("seed", FixSeed(generator=g)), - ], - module__activation=torch.nn.ReLU, - ), - ), - param_distributions=[ - { - "coxphneuralnet__lr": config["tune_lr"], - "coxphneuralnet__optimizer__weight_decay": config[ - "tune_weight_decay" - ], - "coxphneuralnet__module__modality_hidden_layer_size": config[ - "tune_modality_hidden_layer_size" - ], - "coxphneuralnet__module__modality_hidden_layers": config[ - "tune_modality_hidden_layers" - ], - "coxphneuralnet__module__p_dropout": config[ - "tune_p_dropout" - ], - "coxphneuralnet__batch_size": config["tune_batch_size"], - } - ], - n_jobs=config["n_jobs"], - scoring=make_scorer(breslow_score_wrapper), - cv=StratifiedSurvivalKFold( - n_splits=config["n_inner_cv"], - shuffle=config["shuffle_cv"], - random_state=config["random_state"], - ), - error_score=config["error_score"], - verbose=False, - n_iter=config["random_search_n_iter"], - random_state=config["random_state"], - ), - student=make_pipeline( - VarianceThreshold(), - StandardScaler(), - KDPHElasticNetCV( - tie_correction="breslow", - l1_ratio=config[ - f"l1_ratio{'_tuned' if tune_l1_ratio else ''}" - ], - eps=config["eps"], - n_alphas=config["n_alphas"], - cv=config["n_inner_cv"], - stratify_cv=config["stratify_cv"], - seed=np.random.RandomState(config["random_state"]), - shuffle_cv=config["shuffle_cv"], - cv_score_method=score, - n_jobs=config["n_jobs"], - alpha_type=score_type, - ), - ), - ) - for split in range(25): - train_ix = ( - train_splits.iloc[split, :].dropna().to_numpy().astype(int) - ) - test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) - X_train = ( - X_.iloc[train_ix, :] - .copy() - .reset_index(drop=True) - .to_numpy(np.float32) - ) - y_train = y_[train_ix].copy() - y_test = y_[test_ix].copy() - X_test = ( - X_.iloc[test_ix, :] - .copy() - .reset_index(drop=True) - .to_numpy(np.float32) - ) - if split == 0: - results[cancer] = {} - sparsity[cancer] = {} - failures[cancer] = [0] - try: - pipe.fit(X_train, y_train) - sparsity[cancer][split] = np.sum(pipe.student[-1].coef_ != 0) - results[cancer][split] = pipe.predict(X_test) - - surv = pipe.predict_survival_function( - X_test.astype(np.float32), 
np.unique(y_test["time"]) - ) - - surv.to_csv( - f"./results/kd/cox_nnet/{cancer}/survival_function{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}_{split+1}.csv", - index=False, - ) - except ValueError as e: - raise e - pd.concat( - [pd.DataFrame(results[cancer][i]) for i in range(25)], - axis=1, - ).to_csv( - f"./results/kd/cox_nnet/{cancer}/eta{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", - index=False, - ) - pd.DataFrame(sparsity).to_csv( - f"./results/kd/cox_nnet/sparsity{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", - index=False, - ) - pd.DataFrame(failures).to_csv( - f"./results/kd/cox_nnet/failures{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", - index=False, - ) diff --git a/paper/scripts/py/run_teachers.py b/paper/scripts/py/run_teachers.py deleted file mode 100644 index 3c534d4..0000000 --- a/paper/scripts/py/run_teachers.py +++ /dev/null @@ -1,414 +0,0 @@ -import json - -import numpy as np -import pandas as pd -import torch -from sklearn.decomposition import PCA -from sklearn.feature_selection import VarianceThreshold -from sklearn.metrics import make_scorer -from sklearn.model_selection import GridSearchCV, RandomizedSearchCV -from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import StandardScaler -from skorch.callbacks import EarlyStopping -from sksurv.linear_model import CoxPHSurvivalAnalysis -from sksurv.linear_model.coxph import BreslowEstimator - -from sparsesurv.loss import breslow_negative_likelihood, efron_negative_likelihood -from sparsesurv.neuralsurv.python.model.model import SKORCH_MODULE_FACTORY -from sparsesurv.neuralsurv.python.model.skorch_infra import FixSeed -from sparsesurv.neuralsurv.python.utils.factories import ( - CRITERION_FACTORY, - SKORCH_NET_FACTORY, -) -from sparsesurv.neuralsurv.python.utils.misc_utils import ( - StratifiedSkorchSurvivalSplit, - StratifiedSurvivalKFold, -) -from sparsesurv.utils import inverse_transform_survival, transform_survival - - -def breslow_score_wrapper(y_true, y_pred): - time, event = inverse_transform_survival(y_true) - return np.negative( - breslow_negative_likelihood( - linear_predictor=np.squeeze(y_pred), time=time, event=event - ) - ) - - -def efron_score_wrapper(y_true, y_pred): - time, event = inverse_transform_survival(y_true) - return np.negative( - efron_negative_likelihood(linear_predictor=y_pred, time=time, event=event) - ) - - -with open(f"./config.json") as f: - config = json.load(f) - -np.random.seed(config["random_state"]) -g = np.random.default_rng(config.get("random_state")) - - -results = {} -for tie_correction in ["breslow"]: - teacher = GridSearchCV( - estimator=make_pipeline( - VarianceThreshold(), - StandardScaler(), - PCA(n_components=config["pc_n_components"]), - CoxPHSurvivalAnalysis(ties=tie_correction), - ), - param_grid={"pca__n_components": config["pc_n_components_tuned"]}, - n_jobs=config["n_jobs"], - scoring=make_scorer(efron_score_wrapper), - cv=StratifiedSurvivalKFold( - n_splits=config["n_inner_cv"], - shuffle=config["shuffle_cv"], - random_state=config["random_state"], - ), - ) - - for cancer in config["datasets"]: - results[cancer] = {} - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - data = pd.read_csv( - f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - ).iloc[:, 1:] - X_ = data.iloc[:, 3:] - y_ = transform_survival(time=data["OS_days"].values, event=data["OS"].values) - for split 
in range(25): - train_ix = train_splits.iloc[split, :].dropna().to_numpy().astype(int) - test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) - X_train = X_.iloc[train_ix, :].copy().reset_index(drop=True).to_numpy() - y_train = y_[train_ix].copy() - y_test = y_[test_ix].copy() - X_test = X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy() - - teacher.fit(X_train, y_train) - results[cancer][split] = teacher.predict(X_test) - ( - cumulative_baseline_hazards_times, - cumulative_baseline_hazards, - ) = ( - teacher.best_estimator_[3].cum_baseline_hazard_.x, - teacher.best_estimator_[3].cum_baseline_hazard_.y, - ) - cumulative_baseline_hazards = np.concatenate( - [np.array([0.0]), cumulative_baseline_hazards] - ) - cumulative_baseline_hazards_times: np.array = np.concatenate( - [np.array([0.0]), cumulative_baseline_hazards_times] - ) - cumulative_baseline_hazards: np.array = np.tile( - A=cumulative_baseline_hazards[ - np.digitize( - x=np.unique(y_test["time"]), - bins=cumulative_baseline_hazards_times, - right=False, - ) - - 1 - ], - reps=X_test.shape[0], - ).reshape((X_test.shape[0], np.unique(y_test["time"]).shape[0])) - log_hazards: np.array = ( - np.tile( - A=teacher.predict(X_test), - reps=np.unique(y_test["time"]).shape[0], - ) - .reshape((np.unique(y_test["time"]).shape[0], X_test.shape[0])) - .T - ) - surv: pd.DataFrame = np.exp( - -pd.DataFrame( - cumulative_baseline_hazards * np.exp(log_hazards), - columns=np.unique(y_test["time"]), - ) - ) - surv.to_csv( - f"./results/kd/{tie_correction}/{cancer}/survival_function_teacher_{split+1}.csv", - index=False, - ) - pd.concat( - [pd.DataFrame(results[cancer][i]) for i in range(25)], - axis=1, - ).to_csv( - f"./results/kd/{tie_correction}/{cancer}/eta_teacher.csv", - index=False, - ) - -results = {} -for cancer in config["datasets"]: - results[cancer] = {} - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - data = pd.read_csv(f"./data/processed/TCGA/{cancer}_data_preprocessed.csv").iloc[ - :, 1: - ] - X_ = data.iloc[:, 3:] - y_ = transform_survival(time=data["OS_days"].values, event=data["OS"].values) - teacher_cox_nnet = RandomizedSearchCV( - estimator=make_pipeline( - StandardScaler(), - SKORCH_NET_FACTORY["cox"]( - module=SKORCH_MODULE_FACTORY["cox"], - criterion=CRITERION_FACTORY["cox"], - module__fusion_method="early", - module__blocks=[[i for i in range(X_.shape[1])]], - iterator_train__shuffle=True, - optimizer=torch.optim.AdamW, - max_epochs=config["max_epochs"], - verbose=False, - train_split=StratifiedSkorchSurvivalSplit( - config["validation_set_neural"], - stratified=config["stratify_cv"], - random_state=config.get("random_state"), - ), - callbacks=[ - ( - "es", - EarlyStopping( - monitor="valid_loss", - patience=config["early_stopping_patience"], - load_best=True, - ), - ), - ("seed", FixSeed(generator=g)), - ], - module__activation=torch.nn.ReLU, - ), - ), - param_distributions={ - "coxphneuralnet__lr": config["tune_lr"], - "coxphneuralnet__optimizer__weight_decay": config["tune_weight_decay"], - "coxphneuralnet__module__modality_hidden_layer_size": config[ - "tune_modality_hidden_layer_size" - ], - "coxphneuralnet__module__modality_hidden_layers": config[ - "tune_modality_hidden_layers" - ], - "coxphneuralnet__module__p_dropout": config["tune_p_dropout"], - "coxphneuralnet__batch_size": config["tune_batch_size"], - }, - n_jobs=config["n_jobs"], - random_state=config["random_state"], - 
scoring=make_scorer(breslow_score_wrapper), - cv=StratifiedSurvivalKFold( - n_splits=config["n_inner_cv"], - shuffle=config["shuffle_cv"], - random_state=config["random_state"], - ), - error_score=config["error_score"], - verbose=False, - n_iter=config["random_search_n_iter"], - ) - for split in range(25): - train_ix = train_splits.iloc[split, :].dropna().to_numpy().astype(int) - test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) - X_train = ( - X_.iloc[train_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) - ) - y_train = y_[train_ix].copy() - y_test = y_[test_ix].copy() - X_test = X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) - - teacher_cox_nnet.fit(X_train, y_train) - results[cancer][split] = teacher_cox_nnet.predict(X_test) - breslow = BreslowEstimator() - breslow.fit( - linear_predictor=teacher_cox_nnet.predict(X_train), - time=data["OS_days"][train_ix].values, - event=data["OS"][train_ix].values, - ) - ( - cumulative_baseline_hazards_times, - cumulative_baseline_hazards, - ) = ( - breslow.cum_baseline_hazard_.x, - breslow.cum_baseline_hazard_.y, - ) - cumulative_baseline_hazards = np.concatenate( - [np.array([0.0]), cumulative_baseline_hazards] - ) - cumulative_baseline_hazards_times: np.array = np.concatenate( - [np.array([0.0]), cumulative_baseline_hazards_times] - ) - cumulative_baseline_hazards: np.array = np.tile( - A=cumulative_baseline_hazards[ - np.digitize( - x=np.unique(y_test["time"]), - bins=cumulative_baseline_hazards_times, - right=False, - ) - - 1 - ], - reps=X_test.shape[0], - ).reshape((X_test.shape[0], np.unique(y_test["time"]).shape[0])) - log_hazards: np.array = ( - np.tile( - A=teacher_cox_nnet.predict(X_test).squeeze(), - reps=np.unique(y_test["time"]).shape[0], - ) - .reshape((np.unique(y_test["time"]).shape[0], X_test.shape[0])) - .T - ) - surv: pd.DataFrame = np.exp( - -pd.DataFrame( - cumulative_baseline_hazards * np.exp(log_hazards), - columns=np.unique(y_test["time"]), - ) - ) - surv.to_csv( - f"./results/kd/cox_nnet/{cancer}/survival_function_teacher_{split+1}.csv", - index=False, - ) - - pd.concat( - [pd.DataFrame(results[cancer][i]) for i in range(25)], - axis=1, - ).to_csv( - f"./results/kd/cox_nnet/{cancer}/eta_teacher.csv", - index=False, - ) - -results = {} -for cancer in config["datasets"]: - results[cancer] = {} - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - data = pd.read_csv(f"./data/processed/TCGA/{cancer}_data_preprocessed.csv").iloc[ - :, 1: - ] - X_ = data.iloc[:, 3:] - y_ = transform_survival(time=data["OS_days"].values, event=data["OS"].values) - teacher_cox_nnet = RandomizedSearchCV( - estimator=make_pipeline( - VarianceThreshold(), - StandardScaler(), - PCA(n_components=config["pc_n_components"]), - SKORCH_NET_FACTORY["cox"]( - module=SKORCH_MODULE_FACTORY["cox"], - criterion=CRITERION_FACTORY["cox"], - module__fusion_method="early", - module__blocks=[], - iterator_train__shuffle=True, - optimizer=torch.optim.AdamW, - max_epochs=config["max_epochs"], - verbose=False, - train_split=StratifiedSkorchSurvivalSplit( - config["validation_set_neural"], - stratified=config["stratify_cv"], - random_state=config.get("random_state"), - ), - callbacks=[ - ( - "es", - EarlyStopping( - monitor="valid_loss", - patience=config["early_stopping_patience"], - load_best=True, - ), - ), - ("seed", FixSeed(generator=g)), - ], - module__activation=torch.nn.ReLU, - ), - ), - param_distributions=[ - { - 
"pca__n_components": [pc_n_dim], - "coxphneuralnet__module__blocks": [[[i for i in range(pc_n_dim)]]], - "coxphneuralnet__lr": config["tune_lr"], - "coxphneuralnet__optimizer__weight_decay": config["tune_weight_decay"], - "coxphneuralnet__module__modality_hidden_layer_size": config[ - "tune_modality_hidden_layer_size" - ], - "coxphneuralnet__module__modality_hidden_layers": config[ - "tune_modality_hidden_layers" - ], - "coxphneuralnet__module__p_dropout": config["tune_p_dropout"], - "coxphneuralnet__batch_size": config["tune_batch_size"], - } - for pc_n_dim in config["pc_n_components_tuned"] - ], - n_jobs=config["n_jobs"], - random_state=config["random_state"], - scoring=make_scorer(breslow_score_wrapper), - cv=StratifiedSurvivalKFold( - n_splits=config["n_inner_cv"], - shuffle=config["shuffle_cv"], - random_state=config["random_state"], - ), - error_score=config["error_score"], - verbose=False, - n_iter=config["random_search_n_iter"], - ) - for split in range(25): - train_ix = train_splits.iloc[split, :].dropna().to_numpy().astype(int) - test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) - X_train = ( - X_.iloc[train_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) - ) - y_train = y_[train_ix].copy() - y_test = y_[test_ix].copy() - X_test = X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) - - teacher_cox_nnet.fit(X_train, y_train) - results[cancer][split] = teacher_cox_nnet.predict(X_test) - breslow = BreslowEstimator() - breslow.fit( - linear_predictor=teacher_cox_nnet.predict(X_train), - time=data["OS_days"][train_ix].values, - event=data["OS"][train_ix].values, - ) - ( - cumulative_baseline_hazards_times, - cumulative_baseline_hazards, - ) = ( - breslow.cum_baseline_hazard_.x, - breslow.cum_baseline_hazard_.y, - ) - cumulative_baseline_hazards = np.concatenate( - [np.array([0.0]), cumulative_baseline_hazards] - ) - cumulative_baseline_hazards_times: np.array = np.concatenate( - [np.array([0.0]), cumulative_baseline_hazards_times] - ) - cumulative_baseline_hazards: np.array = np.tile( - A=cumulative_baseline_hazards[ - np.digitize( - x=np.unique(y_test["time"]), - bins=cumulative_baseline_hazards_times, - right=False, - ) - - 1 - ], - reps=X_test.shape[0], - ).reshape((X_test.shape[0], np.unique(y_test["time"]).shape[0])) - log_hazards: np.array = ( - np.tile( - A=teacher_cox_nnet.predict(X_test).squeeze(), - reps=np.unique(y_test["time"]).shape[0], - ) - .reshape((np.unique(y_test["time"]).shape[0], X_test.shape[0])) - .T - ) - surv: pd.DataFrame = np.exp( - -pd.DataFrame( - cumulative_baseline_hazards * np.exp(log_hazards), - columns=np.unique(y_test["time"]), - ) - ) - surv.to_csv( - f"./results/kd/cox_nnet/{cancer}/survival_function_teacher_{split+1}.csv", - index=False, - ) - - pd.concat( - [pd.DataFrame(results[cancer][i]) for i in range(25)], - axis=1, - ).to_csv( - f"./results/kd/cox_nnet/{cancer}/eta_teacher.csv", - index=False, - ) diff --git a/paper/scripts/py/time_sparsesurv.py b/paper/scripts/py/time_sparsesurv.py deleted file mode 100644 index 43cda60..0000000 --- a/paper/scripts/py/time_sparsesurv.py +++ /dev/null @@ -1,239 +0,0 @@ -import json -from timeit import default_timer as timer - -import numpy as np -import pandas as pd -import torch -from sklearn.decomposition import PCA -from sklearn.feature_selection import VarianceThreshold -from sklearn.metrics import make_scorer -from sklearn.model_selection import GridSearchCV, RandomizedSearchCV -from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import 
StandardScaler -from skorch.callbacks import EarlyStopping -from sksurv.linear_model import CoxPHSurvivalAnalysis - -from sparsesurv._base import KDSurv -from sparsesurv.cv import KDPHElasticNetCV -from sparsesurv.loss import breslow_negative_likelihood, efron_negative_likelihood -from sparsesurv.neuralsurv.python.model.model import SKORCH_MODULE_FACTORY -from sparsesurv.neuralsurv.python.model.skorch_infra import FixSeed -from sparsesurv.neuralsurv.python.utils.factories import ( - CRITERION_FACTORY, - SKORCH_NET_FACTORY, -) -from sparsesurv.neuralsurv.python.utils.misc_utils import ( - StratifiedSkorchSurvivalSplit, - StratifiedSurvivalKFold, -) -from sparsesurv.utils import inverse_transform_survival, transform_survival - -with open("./config.json") as f: - config = json.load(f) - - -def breslow_score_wrapper(y_true, y_pred): - time, event = inverse_transform_survival(y_true) - return np.negative( - breslow_negative_likelihood( - linear_predictor=np.squeeze(y_pred.astype(np.float64)), - time=time, - event=event, - ) - ) - - -def efron_score_wrapper(y_true, y_pred): - time, event = inverse_transform_survival(y_true) - return np.negative( - efron_negative_likelihood(linear_predictor=y_pred, time=time, event=event) - ) - - -SCORE_FACTORY = {"breslow": breslow_score_wrapper, "efron": efron_score_wrapper} - - -g = np.random.default_rng(config.get("random_state")) -np.random.seed(config["random_state"]) - - -for tune_teacher in [True]: - for tune_l1_ratio in [False]: - for tie_correction in ["breslow"]: - timing = {} - for cancer in config["datasets"]: - timing[cancer] = [] - print(f"Starting: {cancer}") - train_splits = pd.read_csv( - f"./data/splits/TCGA/{cancer}_train_splits.csv" - ) - test_splits = pd.read_csv( - f"./data/splits/TCGA/{cancer}_test_splits.csv" - ) - data = pd.read_csv( - f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - ).iloc[:, 1:] - X_ = data.iloc[:, 3:].to_numpy() - y_ = transform_survival( - time=data["OS_days"].values, event=data["OS"].values - ) - for rep in range(config["timing_reps"]): - pipe = KDSurv( - teacher=GridSearchCV( - estimator=make_pipeline( - StandardScaler(), - PCA(n_components=config["pc_n_components"]), - CoxPHSurvivalAnalysis(ties=tie_correction), - ), - param_grid={ - "pca__n_components": config[ - f"pc_n_components{'_tuned' if tune_teacher else ''}" - ] - }, - n_jobs=1, - verbose=0, - scoring=make_scorer(SCORE_FACTORY[tie_correction]), - cv=StratifiedSurvivalKFold( - n_splits=config["n_inner_cv"], - shuffle=config["shuffle_cv"], - random_state=config["random_state"], - ), - ), - student=make_pipeline( - VarianceThreshold(), - StandardScaler(), - KDPHElasticNetCV( - tie_correction=tie_correction, - l1_ratio=config[ - f"l1_ratio{'_tuned' if tune_l1_ratio else ''}" - ], - eps=config["eps"], - n_alphas=config["n_alphas"], - cv=config["n_inner_cv"], - stratify_cv=config["stratify_cv"], - seed=np.random.RandomState(config["random_state"]), - shuffle_cv=config["shuffle_cv"], - cv_score_method="linear_predictor", - n_jobs=1, - ), - ), - ) - start = timer() - pipe.fit(X_, y_) - end = timer() - timing[cancer].append(end - start) - if tune_l1_ratio: - pd.DataFrame(timing).to_csv( - f"./results/kd/{tie_correction}/timing_tuned_l1_ratio{'_tuned_teacher' if tune_teacher else '' }.csv", - index=False, - ) - else: - pd.DataFrame(timing).to_csv( - f"./results/kd/{tie_correction}/timing{'_tuned_teacher' if tune_teacher else '' }.csv", - index=False, - ) - -for tune_l1_ratio in []: - timing = {} - for cancer in config["datasets"]: - timing[cancer] = [] - 
print(f"Starting: {cancer}") - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - data = pd.read_csv( - f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - ).iloc[:, 1:] - X_ = data.iloc[:, 3:].to_numpy().astype(np.float32) - y_ = transform_survival(time=data["OS_days"].values, event=data["OS"].values) - for rep in range(config["timing_reps"]): - pipe = KDSurv( - teacher=RandomizedSearchCV( - estimator=make_pipeline( - StandardScaler(), - SKORCH_NET_FACTORY["cox"]( - module=SKORCH_MODULE_FACTORY["cox"], - criterion=CRITERION_FACTORY["cox"], - module__fusion_method="early", - module__blocks=[[i for i in range(X_.shape[1])]], - iterator_train__shuffle=True, - optimizer=torch.optim.AdamW, - max_epochs=config["max_epochs"], - verbose=False, - train_split=StratifiedSkorchSurvivalSplit( - config["validation_set_neural"], - stratified=config["stratify_cv"], - random_state=config.get("random_state"), - ), - callbacks=[ - ( - "es", - EarlyStopping( - monitor="valid_loss", - patience=config["early_stopping_patience"], - load_best=True, - ), - ), - ("seed", FixSeed(generator=g)), - ], - module__activation=torch.nn.ReLU, - ), - ), - param_distributions={ - "coxphneuralnet__lr": config["tune_lr"], - "coxphneuralnet__optimizer__weight_decay": config[ - "tune_weight_decay" - ], - "coxphneuralnet__module__modality_hidden_layer_size": config[ - "tune_modality_hidden_layer_size" - ], - "coxphneuralnet__module__modality_hidden_layers": config[ - "tune_modality_hidden_layers" - ], - "coxphneuralnet__module__p_dropout": config["tune_p_dropout"], - "coxphneuralnet__batch_size": config["tune_batch_size"], - }, - n_jobs=config["n_jobs"], - random_state=config["random_state"], - scoring=make_scorer(breslow_score_wrapper), - cv=StratifiedSurvivalKFold( - n_splits=config["n_inner_cv"], - shuffle=config["shuffle_cv"], - random_state=config["random_state"], - ), - error_score=config["error_score"], - verbose=False, - n_iter=config["random_search_n_iter"], - ), - student=make_pipeline( - VarianceThreshold(), - StandardScaler(), - KDPHElasticNetCV( - tie_correction="breslow", - l1_ratio=config[f"l1_ratio{'_tuned' if tune_l1_ratio else ''}"], - eps=config["eps"], - n_alphas=config["n_alphas"], - cv=config["n_inner_cv"], - stratify_cv=config["stratify_cv"], - seed=np.random.RandomState(config["random_state"]), - shuffle_cv=config["shuffle_cv"], - cv_score_method="linear_predictor", - n_jobs=1, - ), - ), - ) - start = timer() - pipe.fit(X_, y_) - end = timer() - timing[cancer].append(end - start) - - if tune_l1_ratio: - pd.DataFrame(timing).to_csv( - f"./results/kd/cox_nnet/timing_tuned_l1_ratio.csv", - index=False, - ) - - else: - pd.DataFrame(timing).to_csv( - f"./results/kd/cox_nnet/timing.csv", - index=False, - ) diff --git a/paper/scripts/r/make_table_S2.R b/paper/scripts/r/make_table_S2.R deleted file mode 100644 index f0234c1..0000000 --- a/paper/scripts/r/make_table_S2.R +++ /dev/null @@ -1,56 +0,0 @@ -library(ggplot2) -library(dplyr) -library(vroom) -library(cowplot) -library(ggpubfigs) -library(ggsignif) -options(warn = 1) - - -config <- rjson::fromJSON( - file = here::here( - "config.json" - ) -) - - -sparsity <- data.frame( - sparsity = c( - c( - unlist(as.vector(vroom::vroom( - here::here( - "results", "non_kd", "breslow", "sparsity_vvh_lambda.min.csv" - ) - )[1:25, ])) - ), - unlist(as.vector(vroom::vroom( - here::here( - "results", "non_kd", "breslow", "sparsity_tuned_l1_ratio_vvh_lambda.min.csv" - ) 
- ))), - unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "breslow", "sparsity_linear_predictor_min.csv" - ) - )[1:25, ])), - unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "breslow", "sparsity_linear_predictor_pcvl.csv" - ) - ))), - unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "cox_nnet", "sparsity_linear_predictor_min.csv" - ) - ))) - ), - cancer = rep(rep(config$datasets, each = 25), 9), - model = rep(c("glmnet (Breslow)", "glmnet tuned (Breslow)", "KD Breslow (min)", "KD Breslow (pcvl)", "KD Cox-Nnet (min)"), each = 225) -) - - -sparsity %>% - group_by(model, cancer) %>% - summarise(value = paste0(round(mean(sparsity), 2), " (", round(sd(sparsity), 2), ")")) %>% - pivot_wider(names_from = cancer, values_from = value) %>% - write_csv(here::here("tables", "table_S2.csv")) diff --git a/paper/scripts/r/make_table_S3.R b/paper/scripts/r/make_table_S3.R deleted file mode 100644 index 42e3715..0000000 --- a/paper/scripts/r/make_table_S3.R +++ /dev/null @@ -1,48 +0,0 @@ -library(ggplot2) -library(dplyr) -library(vroom) -library(cowplot) -library(ggpubfigs) -library(ggsignif) -library(readr) -library(tidyr) -options(warn = 1) - - -config <- rjson::fromJSON( - file = here::here( - "config.json" - ) -) - -metrics <- vroom::vroom(here::here("results", "metrics", "metrics_overall.csv")) - - -regular_metrics <- metrics %>% - filter(model %in% c("breslow", "cox_nnet")) %>% - filter(lambda %in% c("min", "lambda.min", "pcvl")) %>% - filter(metric %in% c("IBS")) -regular_metrics$model_type <- ifelse(regular_metrics$model == "breslow" & regular_metrics$lambda == "pcvl", "KD Breslow (pcvl)", - ifelse(regular_metrics$model == "breslow" & regular_metrics$kd, - "KD Breslow (min)", - ifelse( - regular_metrics$model == "cox_nnet", - "KD Cox-Nnet (min)", - ifelse( - regular_metrics$tuned, "glmnet tuned (Breslow)", - "glmnet (Breslow)" - ) - ) - ) -) -ibs_metrics <- metrics %>% - filter(model == "breslow" & !kd & lambda == 0) %>% - filter(metric %in% c("IBS")) - -regular_metrics %>% - left_join(ibs_metrics, by = c("cancer" = "cancer", "split" = "split")) %>% - mutate(model = model_type, cancer = cancer, split = split, is_kd = value.x == value.y) %>% - group_by(model, cancer) %>% - summarise(sum = sum(is_kd)) %>% - pivot_wider(names_from = cancer, values_from = sum) %>% - write_csv(here::here("tables", "table_S3.csv")) diff --git a/paper/scripts/r/plot_figure_S2.R b/paper/scripts/r/plot_figure_S2.R deleted file mode 100644 index 472267d..0000000 --- a/paper/scripts/r/plot_figure_S2.R +++ /dev/null @@ -1,119 +0,0 @@ -library(ggplot2) -library(dplyr) -library(vroom) -library(cowplot) -library(ggpubfigs) -library(ggsignif) -options(warn = 1) - - -config <- rjson::fromJSON( - file = here::here( - "config.json" - ) -) - - -metrics <- vroom::vroom(here::here("results", "metrics", "metrics_overall.csv")) - -metrics %>% - filter(score %in% c("path")) %>% - filter((metric == "Antolini's C" & model == "breslow" & !kd) | (kd & model == "cox_nnet" & metric == "Antolini's C")) -> path_data -teacher_line <- metrics %>% - filter(score %in% c("teacher")) %>% - filter(metric == "Antolini's C") %>% - filter(model == "cox_nnet") %>% - group_by(cancer) %>% - summarise(mean = mean(value)) -path_data$cancer <- factor(path_data$cancer, levels = cancer_ordering) - - -path_data$model_type <- ifelse(path_data$kd, "KD Cox-Nnet", - "glmnet (Breslow)" -) - - -path_data$model_type <- factor(path_data$model_type, levels = c("glmnet (Breslow)", "KD Cox-Nnet")) -path_data$cancer <- 
factor(path_data$cancer, as.character(cancer_ordering)) -teacher_line$cancer <- factor(teacher_line$cancer, as.character(cancer_ordering)) -path_data_summarised <- path_data %>% - group_by(cancer, model_type, lambda) %>% - summarise(mean = mean(value), sd = sd(value) / sqrt(n())) - -path_data_summarised$cancer <- factor(path_data_summarised$cancer, levels = cancer_ordering) -g <- ggplot(path_data_summarised, aes(x = as.numeric(lambda), group = model_type)) + - geom_line(aes(y = mean, color = model_type), linewidth = 1) + - geom_ribbon(aes(y = mean, ymin = mean - sd, ymax = mean + sd, fill = model_type), alpha = .1) + - scale_color_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 6)]) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 6)]) + - geom_hline(data = teacher_line, aes(yintercept = mean, linetype = "Cox-Nnet teacher"), color = "red", lwd = 0.5, linetype = 2, show.legend = FALSE, alpha = 0.75) + - facet_wrap(~cancer, scales = "free_y", nrow = 2) + - theme_big_simple() + - labs(x = "Regularization index (from sparse to dense)", y = "Antolini's C", fill = "", color = "") - - - -p <- ggplot(mtcars, aes(x = wt, y = mpg)) + - geom_point() -teacher_legend <- p + geom_hline(aes(lty = "Cox-Nnet teacher teacher", yintercept = 20), linewidth = 1, color = "red", show_guide = TRUE) + scale_linetype_manual(name = "", values = 2) + theme_big_simple() + guides(color = guide_legend(override.aes = list(linetype = c("dashed")))) + theme(legend.key.width = unit(2, "cm")) - - -metrics <- vroom::vroom(here::here("results", "metrics", "metrics_overall_fixed.csv")) - -metrics %>% - filter(score %in% c("path")) %>% - filter((metric == "IBS" & model == "breslow" & !kd) | (kd & model == "cox_nnet" & metric == "IBS")) -> path_data -teacher_line <- metrics %>% - filter(score %in% c("teacher")) %>% - filter(metric == "IBS") %>% - filter(model == "cox_nnet") %>% - group_by(cancer) %>% - summarise(mean = mean(value)) - - -path_data$model_type <- ifelse(path_data$kd, "KD Cox-Nnet", - "glmnet (Breslow)" -) - - -path_data$model_type <- factor(path_data$model_type, levels = c("glmnet (Breslow)", "KD Cox-Nnet")) -path_data$cancer <- factor(path_data$cancer, levels = as.character(cancer_ordering)) -teacher_line$cancer <- factor(teacher_line$cancer, as.character(cancer_ordering)) -path_data_summarised <- path_data %>% - group_by(cancer, model_type, lambda) %>% - summarise(mean = mean(value), sd = sd(value) / sqrt(n())) - - -h <- ggplot(path_data_summarised, aes(x = as.numeric(lambda), group = model_type)) + - geom_line(aes(y = mean, color = model_type), linewidth = 1) + - geom_ribbon(aes(y = mean, ymin = mean - sd, ymax = mean + sd, fill = model_type), alpha = .1) + - scale_color_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 6)]) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 6)]) + - geom_hline(data = teacher_line, aes(yintercept = mean, linetype = "Cox-Nnet teacher"), color = "red", lwd = 0.5, linetype = 2, show.legend = FALSE, alpha = 0.75) + - facet_wrap(~cancer, scales = "free_y", nrow = 2) + - theme_big_simple() + - labs(x = "Regularization index (from sparse to dense)", y = "Integrated Brier Score", fill = "", color = "") - - -line_legend <- get_legend( - g + theme(legend.box.margin = margin(0, 0, 0, 0)) -) - -teacher_legend <- get_legend( - teacher_legend + theme(legend.box.margin = margin(0, 0, 0, 0)) -) - -both_legends <- plot_grid( - line_legend, teacher_legend -) - -reg_path <- plot_grid( - cowplot::plot_grid(g + theme(legend.position = "none"), 
both_legends, rel_heights = c(0.95, 0.1), nrow = 2, ncol = 1), - cowplot::plot_grid(h + theme(legend.position = "none"), both_legends, rel_heights = c(0.95, 0.1), nrow = 2, ncol = 1), - labels = c("A", "B"), - nrow = 2, - label_size = 24 -) - -ggsave(here::here("figures", "fig-S2_finalized.pdf"), plot = reg_path, dpi = 300, height = 20 / 1.75, width = 15, units = "in") -ggsave(here::here("figures", "fig-S2_finalized.svg"), plot = reg_path, dpi = 300, height = 20 / 1.75, width = 15, units = "in") diff --git a/paper/scripts/r/plot_figure_S3.R b/paper/scripts/r/plot_figure_S3.R deleted file mode 100644 index 193eefa..0000000 --- a/paper/scripts/r/plot_figure_S3.R +++ /dev/null @@ -1,79 +0,0 @@ -library(ggplot2) -library(dplyr) -library(vroom) -library(cowplot) -library(ggpubfigs) -library(ggsignif) -options(warn = 1) - -config <- rjson::fromJSON( - file = here::here( - "config.json" - ) -) - -metrics <- vroom::vroom(here::here("results", "metrics", "metrics_overall.csv")) -metrics_125 <- vroom::vroom(here::here("results", "metrics", "metrics_overall_125_full.csv")) -metrics_stratified <- vroom::vroom(here::here("results", "metrics", "metrics_overall_cved.csv")) - -metrics <- rbind( - cbind(metrics %>% filter(model == "breslow" & lambda %in% c("lambda.min", "min")), - calc_type = "5-fold CV 5 reps (per split)" - ), - cbind(metrics_125, calc_type = "5-fold CV 25 reps (per split)"), - cbind(metrics_stratified, calc_type = "5-fold CV 25 reps (per CV)") -) - -metrics$model_type <- ifelse( - metrics$kd, "KD Breslow (min)", - ifelse( - metrics$tuned, "glmnet tuned (Breslow)", - "glmnet (Breslow)" - ) -) - -a <- metrics %>% - filter(metric == "Harrell's C") %>% - ggplot(aes(x = model_type, y = value, fill = model_type)) + - geom_boxplot() + - theme_big_simple() + - labs(x = "", y = "Harrell's C", fill = "") + - facet_wrap(~ interaction(calc_type)) + - theme( - axis.title.x = element_blank(), - axis.text.x = element_blank(), - axis.ticks.x = element_blank() - ) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) - -metrics$model_type <- ifelse( - metrics$kd, "KD Breslow (min)", - ifelse( - metrics$tuned, "glmnet tuned (Breslow)", - "glmnet (Breslow)" - ) -) - -b <- metrics %>% - filter(metric == "Uno's C") %>% - ggplot(aes(x = model_type, y = value, fill = model_type)) + - geom_boxplot() + - theme_big_simple() + - labs(x = "", y = "Harrell's C", fill = "") + - facet_wrap(~ interaction(calc_type)) + - theme( - axis.title.x = element_blank(), - axis.text.x = element_blank(), - axis.ticks.x = element_blank() - ) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) - - -cv_fig <- cowplot::plot_grid(a, b, - labels = c("A", "B"), - nrow = 2, - label_size = 24 -) - -ggsave(here::here("figures", "fig-S3_finalized.pdf"), plot = cv_fig, dpi = 300, height = 20 / 1.5, width = 15, units = "in") -ggsave(here::here("figures", "fig-S3_finalized.svg"), plot = cv_fig, dpi = 300, height = 20 / 1.5, width = 15, units = "in") diff --git a/paper/scripts/r/time_glmnet.R b/paper/scripts/r/time_glmnet.R deleted file mode 100644 index 6fe0244..0000000 --- a/paper/scripts/r/time_glmnet.R +++ /dev/null @@ -1,107 +0,0 @@ -library(rjson) -library(glmnet) -library(penAFT) -library(survival) -library(coefplot) -library(pec) -library(readr) -library(vroom) -library(dplyr) -library(microbenchmark) -library(splitTools) -library(glmnetUtils) - - -# Prevent early stoping. 
-glmnet::glmnet.control( - fdev = 0, - devmax = 1.0 -) - -config <- rjson::fromJSON( - file = here::here( - "config.json" - ) -) - -set.seed(config$seed) - -# https://stackoverflow.com/questions/7196450/create-a-dataframe-of-unequal-lengths -na.pad <- function(x, len) { - x[1:len] -} - -makePaddedDataFrame <- function(l, ...) { - maxlen <- max(sapply(l, length)) - data.frame(lapply(l, na.pad, len = maxlen), ...) -} - -timing <- list() - -for (tune_l1_ratio in c(TRUE)) { - for (cancer in c(config$datasets)) { - timing[[cancer]] <- c() - data <- data.frame(vroom::vroom( - here::here( - "data", "processed", "TCGA", - paste0(cancer, "_data_preprocessed.csv") - ) - )[, -1], check.names = FALSE) - x <- as.matrix(data[, -(1:2)]) - y <- Surv(data$OS_days, data$OS) - fold_ids <- rep(0, length(y)) - - fold_helper <- create_folds( - y = data$OS, - k = config$n_inner_cv, - type = c("stratified"), - invert = TRUE, - seed = config$seed - ) - for (i in 1:length(fold_helper)) { - fold_ids[fold_helper[[i]]] <- i - } - if (tune_l1_ratio) { - tim <- microbenchmark( - cva.glmnet( - x = x, - y = y, - family = "cox", - alpha = config$l1_ratio_tuned, - lambda.min.ratio = config$eps, - standardize = TRUE, - nlambda = config$n_alphas, - foldid = fold_ids, - grouped = TRUE, - - ), - times = config$timing_reps - ) - } - else{ - tim <- microbenchmark( - cv.glmnet( - x = x, - y = y, - family = "cox", - alpha = config$l1_ratio, - lambda.min.ratio = config$eps, - standardize = TRUE, - nlambda = config$n_alphas, - foldid = fold_ids, - grouped = TRUE - ), - times = config$timing_reps - ) - } - - timing[[cancer]] <- tim$time * 1e-9 - } - if (tune_l1_ratio) { - data.frame(timing) %>% write_csv(here::here("results", "non_kd", "breslow", "timing_tuned_l1_ratio.csv")) - } - else { - data.frame(timing) %>% write_csv(here::here("results", "non_kd", "breslow", "timing.csv")) - } - -} diff --git a/paper/scripts/sh/create_paths.sh b/paper/scripts/sh/create_paths.sh deleted file mode 100644 index 78f7f88..0000000 --- a/paper/scripts/sh/create_paths.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -for model in "breslow" "cox_nnet"; do - for cancer in "BLCA" "BRCA" "HNSC" "KIRC" "LGG" \ - "LIHC" "LUAD" "LUSC" "OV" "STAD"; do - mkdir -p ./results/kd/$model/$cancer - mkdir -p ./results/kd/$model/$cancer/path - - done -done - -for model in "breslow"; do - for cancer in "BLCA" "BRCA" "HNSC" "KIRC" "LGG" \ - "LIHC" "LUAD" "LUSC" "OV" "STAD"; do - mkdir -p ./results/non_kd/$model/$cancer - mkdir -p ./results/non_kd/$model/$cancer/path - done -done - -mkdir -p ./results/metrics/ diff --git a/paper/scripts/sh/download_data.sh b/paper/scripts/sh/download_data.sh deleted file mode 100644 index ce1a8fe..0000000 --- a/paper/scripts/sh/download_data.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -mkdir -p ./data -mkdir -p ./data/raw - -for id in "3586c0da-64d0-4b74-a449-5ff4d9136611" \ - "1b5f413e-a8d1-4d10-92eb-7c4ae739ed81" \ - "0fc78496-818b-4896-bd83-52db1f533c5c"; do - wget --content-disposition http://api.gdc.cancer.gov/data/${id} -P ./data/raw/ - sleep 60 -done diff --git a/paper/scripts/sh/make_splits.sh b/paper/scripts/sh/make_splits.sh deleted file mode 100644 index 843ae4b..0000000 --- a/paper/scripts/sh/make_splits.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -mkdir -p ./data/splits - -python -u scripts/py/rerun_splits.py \ - --data_dir ./data/ \ - --config_path ./ diff --git a/paper/scripts/sh/preprocess_data.sh b/paper/scripts/sh/preprocess_data.sh deleted file mode 100644 index 44ceb6a..0000000 --- 
a/paper/scripts/sh/preprocess_data.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -mkdir -p ./data/processed -mkdir -p ./data/processed/TCGA - -Rscript scripts/r/run_preprocessing.R diff --git a/paper/scripts/sh/remake_figures_and_tables.sh b/paper/scripts/sh/remake_figures_and_tables.sh deleted file mode 100644 index 7e6f3fd..0000000 --- a/paper/scripts/sh/remake_figures_and_tables.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -Rscript scripts/r/plot_figure_1.R - -Rscript scripts/r/plot_figure_S1.R -Rscript scripts/r/plot_figure_S2.R -Rscript scripts/r/plot_figure_S3.R - -Rscript scripts/r/make_table_S2.R -Rscript scripts/r/make_table_S3.R -Rscript scripts/r/make_table_S4.R - -python scripts/py/make_table_S1.py diff --git a/paper/scripts/sh/rerun_experiments.sh b/paper/scripts/sh/rerun_experiments.sh deleted file mode 100644 index 06fb287..0000000 --- a/paper/scripts/sh/rerun_experiments.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -Rscript scripts/r/time_glmnet.R -python scripts/py/time_sparsesurv.py - -Rscript scripts/r/run_path_glmnet.R -python scripts/py/run_path_sparsesurv.py - -Rscript scripts/r/run_glmnet.R -python scripts/py/run_sparsesurv.py - -python scripts/py/run_teachers.py diff --git a/setup.py b/setup.py index 2c7bb6c..01bf935 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ def read_version(filepath: str) -> str: author_email="dwissel@inf.ethz.ch, jnikita@inf.ethz.ch", # the following exclusion is to prevent shipping of tests. # if you do include them, add pytest to the required packages. - packages=find_packages(".", exclude=["*tests*"]), + packages=["sparsesurv", "sparsesurv.neuralsurv", "sparsesurv.neuralsurv.python", "sparsesurv.neuralsurv.python.utils", "sparsesurv.neuralsurv.python.model"], package_data={"sparsesurv": ["py.typed"]}, extras_require={ "vcs": VCS_REQUIREMENTS, diff --git a/sparsesurv/neuralsurv/python/__init__.py b/sparsesurv/neuralsurv/python/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sparsesurv/neuralsurv/python/model/__init__.py b/sparsesurv/neuralsurv/python/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sparsesurv/neuralsurv/python/utils/__init__.py b/sparsesurv/neuralsurv/python/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/workflow/README.md b/workflow/README.md new file mode 100644 index 0000000..c4ac890 --- /dev/null +++ b/workflow/README.md @@ -0,0 +1,30 @@ +# sparsesurv: A Python package for fitting sparse survival models via knowledge distillation + +## Abstract + +Sparse survival models are statistical models that select a subset of predictor variables while modeling the time until an event occurs, which can subsequently help interpretability and transportability. The subset of important features is often obtained with regularized models, such as the Cox Proportional Hazards model with Lasso regularization, which limit the number of non-zero coefficients. However, such models can be sensitive to the choice of regularization hyperparameter. In this work, we develop a software package and demonstrate how knowledge distillation, a powerful technique in machine learning that aims to transfer knowledge from a complex teacher model to a simpler student model, can be leveraged to learn sparse survival models while mitigating this challenge. 
For this purpose, we present sparsesurv, a Python package that contains a set of teacher-student model pairs, including the semi-parametric accelerated failure time and the extended hazards models as teachers, which currently do not have Python implementations. It also contains in-house survival function estimators, removing the need for external packages. Sparsesurv is validated against R-based Elastic Net regularized linear Cox proportional hazards models as implemented in the commonly used glmnet package. Our results reveal that knowledge distillation-based approaches achieve competitive discriminative performance relative to glmnet across the regularization path while making the choice of the regularization hyperparameter significantly easier. All of these features, combined with an sklearn-like API, make sparsesurv an easy-to-use Python package that enables survival analysis for high-dimensional datasets by fitting sparse survival models via knowledge distillation. + +## Reproducibility + +### From scratch + +The entire pipeline, from data download and preprocessing through to the final figures and tables, is reproduced with a single Snakemake invocation (a dry-run example follows this README): + +```sh +snakemake --use-conda --conda-frontend mamba --cores 12 +``` + +### Results + +All of our results, including preprocessed data, computed performance metrics and predicted survival functions for all models and experiments are available on [Zenodo](https://zenodo.org/doi/10.5281/zenodo.8280014). + +## Questions + +In case of any questions, please reach out to david.wissel@inf.ethz.ch or open an issue in this repo. + +## Citation + +Our manuscript is still under review. + +## References + +[1] Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. "Distilling the knowledge in a neural network." arXiv preprint arXiv:1503.02531 (2015). +[2] Paul, Debashis, et al. ""Preconditioning" for Feature Selection and Regression in High-Dimensional Problems." The Annals of Statistics (2008): 1595-1618.
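The single `snakemake` call in the README above drives the whole workflow defined in the Snakefile that follows. As a usage sketch (standard Snakemake flags; the example target path is taken from the `run_teachers` rule below and assumes you run from the repository root), the DAG can be previewed with a dry run, or a single output can be requested instead of the full pipeline:

```sh
# Preview the jobs Snakemake would schedule, without executing anything.
snakemake -n --use-conda

# Build one target instead of the full pipeline, e.g. only the Breslow
# teacher predictions for BLCA.
snakemake --use-conda --conda-frontend mamba --cores 12 \
    results/kd/breslow/BLCA/eta_teacher.csv
```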
diff --git a/workflow/Snakefile b/workflow/Snakefile new file mode 100644 index 0000000..1f4025e --- /dev/null +++ b/workflow/Snakefile @@ -0,0 +1,517 @@ +from snakemake.utils import min_version + + +configfile: "config/config.yaml" + + +min_version(config["snakemake_min_version"]) + + +container: f"docker://condaforge/mambaforge:{config['mambaforge_version']}" + + +datasets = config["datasets"] + + +rule all: + input: + "results/figures/fig-1_finalized.svg", + "results/figures/fig-1_finalized.png", + "results/figures/fig-S1_finalized.svg", + "results/figures/fig-S1_finalized.pdf", + "results/figures/fig-S2_finalized.svg", + "results/figures/fig-S2_finalized.pdf", + "results/figures/fig-S3_finalized.svg", + "results/figures/fig-S3_finalized.pdf", + "results/tables/table_S1.csv", + "results/tables/table_S2.csv", + "results/tables/table_S3.csv", + "results/tables/table_S4.csv", + + +rule download_data: + output: + gex="results/download_data/gex.tsv", + cdr="results/download_data/cdr.xlsx", + followup="results/download_data/followup.tsv", + log: + "logs/download_data/log.out", + shell: + """ + wget --content-disposition http://api.gdc.cancer.gov/data/3586c0da-64d0-4b74-a449-5ff4d9136611 -O {output.gex} &> {log}; + wget --content-disposition http://api.gdc.cancer.gov/data/1b5f413e-a8d1-4d10-92eb-7c4ae739ed81 -O {output.cdr} &>> {log}; + wget --content-disposition http://api.gdc.cancer.gov/data/0fc78496-818b-4896-bd83-52db1f533c5c -O {output.followup} &>> {log} + """ + + +rule preprocess_data: + input: + gex="results/download_data/gex.tsv", + cdr="results/download_data/cdr.xlsx", + followup="results/download_data/followup.tsv", + output: + output_path="results/preprocess_data/{datasets}.csv", + params: + cancer="{datasets}", + log: + "logs/preprocess_data/{datasets}.out", + conda: + "envs/preprocess.yaml" + script: + "scripts/r/preprocess_data.R" + + +rule make_splits: + input: + data_path="results/preprocess_data/{cancer}.csv", + output: + train_splits_output_path="results/make_splits/{cancer}_train_splits.csv", + test_splits_output_path="results/make_splits/{cancer}_test_splits.csv", + params: + random_seed=config["random_seed"], + n_outer_repetitions=config["n_outer_repetitions"], + n_outer_splits=config["n_outer_splits"], + log: + "logs/make_splits/{cancer}.out", + conda: + "envs/make_splits.yaml" + script: + "scripts/py/make_splits.py" + + +rule time_glmnet: + input: + data_path=expand("results/preprocess_data/{cancer}.csv", cancer=datasets), + output: + "results/non_kd/breslow/timing_tuned_l1_ratio.csv", + "results/non_kd/breslow/timing.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/time_glmnet/log.out", + conda: + "envs/glmnet.yaml" + script: + "scripts/r/time_glmnet.R" + + +rule run_path_glmnet: + input: + data_path=expand("results/preprocess_data/{cancer}.csv", cancer=datasets), + train_splits_output_path=expand( + "results/make_splits/{cancer}_train_splits.csv", cancer=datasets + ), + test_splits_output_path=expand( + "results/make_splits/{cancer}_test_splits.csv", cancer=datasets + ), + output: + expand( + "results/non_kd/breslow/{cancer}/path/sparsity.csv", + cancer=datasets, + ), + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/run_path_glmnet/log.out", + conda: + "envs/glmnet.yaml" + script: + "scripts/r/run_path_glmnet.R" + + +rule run_glmnet: + input: + data_path=expand("results/preprocess_data/{cancer}.csv", cancer=datasets), + train_splits_output_path=expand( + "results/make_splits/{cancer}_train_splits.csv",
cancer=datasets + ), + test_splits_output_path=expand( + "results/make_splits/{cancer}_test_splits.csv", cancer=datasets + ), + output: + "results/non_kd/breslow/sparsity_vvh_lambda.min.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/run_glmnet/log.out", + conda: + "envs/glmnet.yaml" + script: + "scripts/r/run_glmnet.R" + + +rule run_glmnet_tuned: + input: + data_path="results/preprocess_data/{cancer}.csv", + train_splits_output_path="results/make_splits/{cancer}_train_splits.csv", + test_splits_output_path="results/make_splits/{cancer}_test_splits.csv", + output: + "results/non_kd/breslow/{cancer}/sparsity_tuned_l1_ratio_vvh_lambda.min.csv", + params: + config_path=config["config_path"], + cancer="{cancer}", + threads: 1 + log: + "logs/run_glmnet_tuned/{cancer}/log.out", + conda: + "envs/glmnet.yaml" + script: + "scripts/r/run_glmnet_tuned.R" + + +rule time_sparsesurv: + input: + data_path=expand("results/preprocess_data/{cancer}.csv", cancer=datasets), + output: + "results/kd/cox_nnet/timing.csv", + "results/kd/breslow/timing_tuned_teacher.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/time_sparsesurv/log.out", + conda: + "envs/sparsesurv.yaml" + script: + "scripts/py/time_sparsesurv.py" + + +rule run_teachers: + input: + data_path=expand("results/preprocess_data/{cancer}.csv", cancer=datasets), + train_splits_output_path=expand( + "results/make_splits/{cancer}_train_splits.csv", cancer=datasets + ), + test_splits_output_path=expand( + "results/make_splits/{cancer}_test_splits.csv", cancer=datasets + ), + output: + expand("results/kd/cox_nnet/{cancer}/eta_teacher.csv", cancer=datasets), + expand("results/kd/breslow/{cancer}/eta_teacher.csv", cancer=datasets), + params: + config_path=config["config_path"], + threads: 12 + log: + "logs/run_teachers/log.out", + conda: + "envs/sparsesurv.yaml" + script: + "scripts/py/run_teachers.py" + + +rule run_path_sparsesurv: + input: + data_path=expand("results/preprocess_data/{cancer}.csv", cancer=datasets), + train_splits_output_path=expand( + "results/make_splits/{cancer}_train_splits.csv", cancer=datasets + ), + test_splits_output_path=expand( + "results/make_splits/{cancer}_test_splits.csv", cancer=datasets + ), + output: + expand("results/kd/cox_nnet/{cancer}/path/sparsity.csv", cancer=datasets), + expand("results/kd/breslow/{cancer}/path/sparsity.csv", cancer=datasets), + params: + config_path=config["config_path"], + threads: 12 + log: + "logs/run_path_sparsesurv/log.out", + conda: + "envs/sparsesurv.yaml" + script: + "scripts/py/run_path_sparsesurv.py" + + +rule run_sparsesurv: + input: + data_path=expand("results/preprocess_data/{cancer}.csv", cancer=datasets), + train_splits_output_path=expand( + "results/make_splits/{cancer}_train_splits.csv", cancer=datasets + ), + test_splits_output_path=expand( + "results/make_splits/{cancer}_test_splits.csv", cancer=datasets + ), + output: + expand( + "results/kd/cox_nnet/{cancer}/eta_linear_predictor_min.csv", + cancer=datasets, + ), + expand( + "results/kd/breslow/{cancer}/eta_linear_predictor_min.csv", + cancer=datasets, + ), + expand( + "results/kd/breslow/{cancer}/eta_linear_predictor_pcvl.csv", + cancer=datasets, + ), + params: + config_path=config["config_path"], + threads: 12 + log: + "logs/run_sparsesurv/log.out", + conda: + "envs/sparsesurv.yaml" + script: + "scripts/py/run_sparsesurv.py" + + +rule make_metrics_teachers: + input: + expand("results/kd/cox_nnet/{cancer}/eta_teacher.csv", cancer=datasets), + 
expand("results/kd/breslow/{cancer}/eta_teacher.csv", cancer=datasets), + output: + "results/metrics/metrics_overall_teachers.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/make_metrics_teachers/log.out", + conda: + "envs/make_metrics.yaml" + script: + "scripts/py/make_metrics_teachers.py" + + +rule make_metrics_cved: + input: + expand( + "results/kd/cox_nnet/{cancer}/eta_linear_predictor_min.csv", + cancer=datasets, + ), + expand( + "results/kd/breslow/{cancer}/eta_linear_predictor_min.csv", + cancer=datasets, + ), + expand( + "results/kd/breslow/{cancer}/eta_linear_predictor_pcvl.csv", + cancer=datasets, + ), + expand("results/kd/cox_nnet/{cancer}/path/sparsity.csv", cancer=datasets), + expand("results/kd/breslow/{cancer}/path/sparsity.csv", cancer=datasets), + expand( + "results/non_kd/breslow/{cancer}/path/sparsity.csv", + cancer=datasets, + ), + expand( + "results/non_kd/breslow/{cancer}/sparsity_tuned_l1_ratio_vvh_lambda.min.csv", + cancer=datasets, + ), + expand( + "results/non_kd/breslow/sparsity_vvh_lambda.min.csv", + cancer=datasets, + ), + output: + "results/metrics/metrics_overall_cved.csv", + "results/metrics/metrics_overall_125_full.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/make_metrics_cved/log.out", + conda: + "envs/make_metrics.yaml" + script: + "scripts/py/make_metrics_cved.py" + + +rule make_metrics: + input: + expand( + "results/kd/cox_nnet/{cancer}/eta_linear_predictor_min.csv", + cancer=datasets, + ), + expand( + "results/kd/breslow/{cancer}/eta_linear_predictor_min.csv", + cancer=datasets, + ), + expand( + "results/kd/breslow/{cancer}/eta_linear_predictor_pcvl.csv", + cancer=datasets, + ), + expand("results/kd/cox_nnet/{cancer}/path/sparsity.csv", cancer=datasets), + expand("results/kd/breslow/{cancer}/path/sparsity.csv", cancer=datasets), + expand( + "results/non_kd/breslow/{cancer}/path/sparsity.csv", + cancer=datasets, + ), + expand( + "results/non_kd/breslow/{cancer}/sparsity_tuned_l1_ratio_vvh_lambda.min.csv", + cancer=datasets, + ), + expand( + "results/non_kd/breslow/sparsity_vvh_lambda.min.csv", + cancer=datasets, + ), + output: + "results/metrics/metrics_overall.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/make_metrics/log.out", + conda: + "envs/make_metrics.yaml" + script: + "scripts/py/make_metrics.py" + + +rule make_table_S1: + input: + data_path=expand("results/preprocess_data/{cancer}.csv", cancer=datasets), + train_splits_output_path=expand( + "results/make_splits/{cancer}_train_splits.csv", cancer=datasets + ), + test_splits_output_path=expand( + "results/make_splits/{cancer}_test_splits.csv", cancer=datasets + ), + output: + "results/tables/table_S1.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/make_table_S1/log.out", + conda: + "envs/make_metrics.yaml" + script: + "scripts/py/make_table_S1.py" + + +rule make_table_S2: + input: + "results/metrics/metrics_overall.csv", + "results/metrics/metrics_overall_cved.csv", + "results/metrics/metrics_overall_125_full.csv", + "results/metrics/metrics_overall_teachers.csv", + output: + "results/tables/table_S2.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/make_table_S2/log.out", + conda: + "envs/plots_and_tables.yaml" + script: + "scripts/r/make_table_S2.R" + + +rule make_table_S3: + input: + "results/metrics/metrics_overall.csv", + "results/metrics/metrics_overall_cved.csv", + "results/metrics/metrics_overall_125_full.csv", + 
"results/metrics/metrics_overall_teachers.csv", + output: + "results/tables/table_S3.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/make_table_S3/log.out", + conda: + "envs/plots_and_tables.yaml" + script: + "scripts/r/make_table_S3.R" + + +rule make_table_S4: + input: + "results/metrics/metrics_overall.csv", + "results/metrics/metrics_overall_cved.csv", + "results/metrics/metrics_overall_125_full.csv", + "results/metrics/metrics_overall_teachers.csv", + output: + "results/tables/table_S4.csv", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/make_table_S4/log.out", + conda: + "envs/plots_and_tables.yaml" + script: + "scripts/r/make_table_S4.R" + + +rule plot_figure_1: + input: + "results/metrics/metrics_overall.csv", + "results/metrics/metrics_overall_cved.csv", + "results/metrics/metrics_overall_125_full.csv", + "results/metrics/metrics_overall_teachers.csv", + output: + "results/figures/fig-1_finalized.svg", + "results/figures/fig-1_finalized.png", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/plot_figure_1/log.out", + conda: + "envs/plots_and_tables.yaml" + script: + "scripts/r/plot_figure_1.R" + + +rule plot_figure_S1: + input: + "results/metrics/metrics_overall.csv", + "results/metrics/metrics_overall_cved.csv", + "results/metrics/metrics_overall_125_full.csv", + "results/metrics/metrics_overall_teachers.csv", + output: + "results/figures/fig-S1_finalized.svg", + "results/figures/fig-S1_finalized.pdf", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/plot_figure_S1/log.out", + conda: + "envs/plots_and_tables.yaml" + script: + "scripts/r/plot_figure_S1.R" + + +rule plot_figure_S2: + input: + "results/metrics/metrics_overall.csv", + "results/metrics/metrics_overall_cved.csv", + "results/metrics/metrics_overall_125_full.csv", + "results/metrics/metrics_overall_teachers.csv", + output: + "results/figures/fig-S2_finalized.svg", + "results/figures/fig-S2_finalized.pdf", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/plot_figure_S2/log.out", + conda: + "envs/plots_and_tables.yaml" + script: + "scripts/r/plot_figure_S2.R" + + +rule plot_figure_S3: + input: + "results/metrics/metrics_overall.csv", + "results/metrics/metrics_overall_cved.csv", + "results/metrics/metrics_overall_125_full.csv", + "results/metrics/metrics_overall_teachers.csv", + output: + "results/figures/fig-S3_finalized.svg", + "results/figures/fig-S3_finalized.pdf", + params: + config_path=config["config_path"], + threads: 1 + log: + "logs/plot_figure_S3/log.out", + conda: + "envs/plots_and_tables.yaml" + script: + "scripts/r/plot_figure_S3.R" diff --git a/workflow/envs/glmnet.yaml b/workflow/envs/glmnet.yaml new file mode 100644 index 0000000..5b863be --- /dev/null +++ b/workflow/envs/glmnet.yaml @@ -0,0 +1,18 @@ +name: glmnet +channels: + - conda-forge + - bioconda + - defaults +dependencies: + - r-base=4.3.3 + - r-dplyr=1.1.3 + - r-readr=2.1.5 + - r-vroom=1.6.5 + - r-glmnet=4.1 + - r-microbenchmark=1.4.10 + - r-splittools=1.0.1 + - r-glmnetutils=1.1.9 + - r-coefplot=1.2.8 + - r-pec=2023.04.12 + - r-rjson=0.2.21 + diff --git a/workflow/envs/make_metrics.yaml b/workflow/envs/make_metrics.yaml new file mode 100644 index 0000000..87140f0 --- /dev/null +++ b/workflow/envs/make_metrics.yaml @@ -0,0 +1,15 @@ +name: make_splits +channels: + - conda-forge + - bioconda + - pytorch + - defaults +dependencies: + - python=3.8.0 + - pandas=1.5.1 + - pycox=0.2.3 + - scikit-survival=0.22.2 + - pip 
+ - pip: + - ../../ + #- https://github.com/BoevaLab/sparsesurv@paper diff --git a/workflow/envs/make_splits.yaml b/workflow/envs/make_splits.yaml new file mode 100644 index 0000000..ab016c4 --- /dev/null +++ b/workflow/envs/make_splits.yaml @@ -0,0 +1,9 @@ +name: make_splits +channels: + - conda-forge + - bioconda + - defaults +dependencies: + - python=3.12.3 + - pandas=2.2.2 + - scikit-learn=1.5.0 diff --git a/workflow/envs/plots_and_tables.yaml b/workflow/envs/plots_and_tables.yaml new file mode 100644 index 0000000..d7b2d98 --- /dev/null +++ b/workflow/envs/plots_and_tables.yaml @@ -0,0 +1,18 @@ +name: preprocess +channels: + - conda-forge + - bioconda + - defaults +dependencies: + - r-base=4.3.3 + - r-dplyr=1.1.3 + - r-readr=2.1.5 + - r-tidyr=1.3.1 + - r-stringr=1.5.1 + - r-vroom=1.6.5 + - r-readxl=1.4.3 + - r-ggplot2=3.4.4 + - r-cowplot=1.1.3 + - r-ggsignif=0.6.4 + - r-rjson=0.2.21 + - r-svglite=2.1.3 diff --git a/workflow/envs/preprocess.yaml b/workflow/envs/preprocess.yaml new file mode 100644 index 0000000..8224ade --- /dev/null +++ b/workflow/envs/preprocess.yaml @@ -0,0 +1,13 @@ +name: preprocess +channels: + - conda-forge + - bioconda + - defaults +dependencies: + - r-base=4.3.3 + - r-dplyr=1.1.3 + - r-readr=2.1.5 + - r-tidyr=1.3.1 + - r-stringr=1.5.1 + - r-vroom=1.6.5 + - r-readxl=1.4.3 diff --git a/workflow/envs/sparsesurv.yaml b/workflow/envs/sparsesurv.yaml new file mode 100644 index 0000000..3352e20 --- /dev/null +++ b/workflow/envs/sparsesurv.yaml @@ -0,0 +1,24 @@ +name: sparsesurv +channels: + - conda-forge + - bioconda + - pytorch + - sebp + - defaults +dependencies: + - python=3.11.0 + - pytorch=2.1.2 + - scikit-learn=1.3.0 + - pandas=1.5.3 + - skorch=0.15.0 + - numba=0.59.1 + - scikit-survival=0.22.2 + #- gcc_linux-64=13.2.0 + #- gxx_linux-64=13.2.0 + #- libgcc-ng=13.2.0 + - cxx-compiler=1.7.0 + - pip + - pip: + #- git+https://github.com/BoevaLab/sparsesurv@paper + - celer==0.7.3 + - ../../ diff --git a/paper/scripts/py/make_metrics.py b/workflow/scripts/py/make_metrics.py similarity index 91% rename from paper/scripts/py/make_metrics.py rename to workflow/scripts/py/make_metrics.py index 7b6b29f..052555b 100644 --- a/paper/scripts/py/make_metrics.py +++ b/workflow/scripts/py/make_metrics.py @@ -1,23 +1,21 @@ -import json -import os +import sys -import numpy as np -import pandas as pd -from pycox.evaluation import EvalSurv -from sksurv.metrics import concordance_index_censored, concordance_index_ipcw -from sksurv.util import Surv +with open(snakemake.log[0], "w") as f: + sys.stderr = sys.stdout = f + import json -def main() -> int: - with open("./config.json") as f: + import numpy as np + import pandas as pd + from pycox.evaluation import EvalSurv + from sksurv.metrics import concordance_index_censored, concordance_index_ipcw + from sksurv.util import Surv + + with open(snakemake.params["config_path"]) as f: config = json.load(f) np.random.seed(config["random_state"]) sksurv_converter = Surv() transform_survival = sksurv_converter.from_arrays - - splits_path = os.path.join(".", "data", "splits", "TCGA") - data_path = os.path.join(".", "data", "processed", "TCGA") - os.makedirs(splits_path, exist_ok=True) model = [] pc = [] score = [] @@ -31,12 +29,11 @@ def main() -> int: for cancer in config["datasets"]: print(f"Starting: {cancer}") - data_path = f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - df = pd.read_csv(data_path) + df = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] time = df["OS_days"].values event = df["OS"].values - test_splits 
= pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") for n_variables in [""]: if n_variables == "": n_variables_string = 0 @@ -57,7 +54,7 @@ def main() -> int: for score_function in ["linear_predictor"]: for model_type in model_list: lp = pd.read_csv( - f"./results/kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" + f"results/kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" ) for i in range(25): lp_split = lp.iloc[:, i].dropna().values @@ -112,7 +109,7 @@ def main() -> int: for score_function in ["vvh"]: for model_type in ["breslow"]: lp = pd.read_csv( - f"./results/non_kd/{model_type}/{cancer}/eta_tuned_l1_ratio_{score_function}_{lambda_type+n_variables}.csv" + f"results/non_kd/{model_type}/{cancer}/eta_tuned_l1_ratio_{score_function}_{lambda_type+n_variables}.csv" ).iloc[:, 1:] for i in range(25): lp_split = lp.iloc[:, i].dropna().values @@ -168,7 +165,7 @@ def main() -> int: for score_function in ["vvh"]: for model_type in ["breslow"]: lp = pd.read_csv( - f"./results/non_kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" + f"results/non_kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" ).iloc[:, 1:] for i in range(25): lp_split = lp.iloc[:, i].dropna().values @@ -230,7 +227,7 @@ def main() -> int: for i in range(25): surv = pd.read_csv( - f"./results/kd/{model_type}/{cancer}/survival_function_{score_function}_{lambda_type}_{str(i+1)+n_variables}.csv" + f"results/kd/{model_type}/{cancer}/survival_function_{score_function}_{lambda_type}_{str(i+1)+n_variables}.csv" ).T surv.index = surv.index.astype(float) test_split = ( @@ -269,7 +266,7 @@ def main() -> int: for i in range(25): surv = pd.read_csv( - f"./results/non_kd/{model_type}/{cancer}/survival_function_tuned_l1_ratio_{score_function}_{lambda_type}_{str(i+1)+n_variables}.csv" + f"results/non_kd/{model_type}/{cancer}/survival_function_tuned_l1_ratio_{score_function}_{lambda_type}_{str(i+1)+n_variables}.csv" ).T surv.index = surv.index.astype(float) test_split = ( @@ -308,7 +305,7 @@ def main() -> int: for i in range(25): surv = pd.read_csv( - f"./results/non_kd/{model_type}/{cancer}/survival_function_{score_function}_{lambda_type}_{str(i+1)+n_variables}.csv" + f"results/non_kd/{model_type}/{cancer}/survival_function_{score_function}_{lambda_type}_{str(i+1)+n_variables}.csv" ).T surv.index = surv.index.astype(float) test_split = ( @@ -345,7 +342,7 @@ def main() -> int: for i in range(25): surv = pd.read_csv( - f"./results/kd/{model_type}/{cancer}/survival_function_teacher_{i+1}.csv" + f"results/kd/{model_type}/{cancer}/survival_function_teacher_{i+1}.csv" ).T surv.index = surv.index.astype(float) test_split = test_splits.iloc[i, :].dropna().values.astype(int) @@ -376,7 +373,7 @@ def main() -> int: for model_type in ["breslow", "cox_nnet"]: for i in range(25): surv = pd.read_csv( - f"./results/kd/{model_type}/{cancer}/path/survival_function_{path_num+1}_alpha_{i+1}.csv" + f"results/kd/{model_type}/{cancer}/path/survival_function_{path_num+1}_alpha_{i+1}.csv" ).T surv.index = surv.index.astype(float) @@ -408,7 +405,7 @@ def main() -> int: for i in range(25): surv = pd.read_csv( - f"./results/non_kd/{model_type}/{cancer}/path/survival_function_{path_num+1}_alpha_{i+1}.csv" + 
f"results/non_kd/{model_type}/{cancer}/path/survival_function_{path_num+1}_alpha_{i+1}.csv" ).T surv.index = surv.index.astype(float) test_split = test_splits.iloc[i, :].dropna().values.astype(int) @@ -447,9 +444,4 @@ def main() -> int: "n_variables": variables, "tuned": tuned, } - ).to_csv("./results/metrics/metrics_overall.csv", index=False) - return 0 - - -if __name__ == "__main__": - main() + ).to_csv("results/metrics/metrics_overall.csv", index=False) diff --git a/paper/scripts/py/make_metrics_cved.py b/workflow/scripts/py/make_metrics_cved.py similarity index 92% rename from paper/scripts/py/make_metrics_cved.py rename to workflow/scripts/py/make_metrics_cved.py index 91bad37..6023c4a 100644 --- a/paper/scripts/py/make_metrics_cved.py +++ b/workflow/scripts/py/make_metrics_cved.py @@ -1,22 +1,21 @@ -import json -import os +import sys -import numpy as np -import pandas as pd -from sksurv.metrics import concordance_index_censored, concordance_index_ipcw -from sksurv.util import Surv +with open(snakemake.log[0], "w") as f: + sys.stderr = sys.stdout = f + import json -def main() -> int: - with open("./config.json") as f: + import numpy as np + import pandas as pd + from sksurv.metrics import concordance_index_censored, concordance_index_ipcw + from sksurv.util import Surv + + with open(snakemake.params["config_path"]) as f: config = json.load(f) np.random.seed(config["random_state"]) sksurv_converter = Surv() transform_survival = sksurv_converter.from_arrays - splits_path = os.path.join(".", "data", "splits", "TCGA") - data_path = os.path.join(".", "data", "processed", "TCGA") - os.makedirs(splits_path, exist_ok=True) model = [] pc = [] score = [] @@ -30,12 +29,11 @@ def main() -> int: for cancer in config["datasets"]: print(f"Starting: {cancer}") - data_path = f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - df = pd.read_csv(data_path) + df = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] time = df["OS_days"].values event = df["OS"].values - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") for n_variables in [""]: if n_variables == "": n_variables_string = 0 @@ -56,7 +54,7 @@ def main() -> int: for score_function in ["linear_predictor"]: for model_type in model_list: lp = pd.read_csv( - f"./results/kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" + f"results/kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" ) for i in range(125): lp_split = lp.iloc[:, i].dropna().values @@ -111,7 +109,7 @@ def main() -> int: for score_function in ["vvh"]: for model_type in ["breslow"]: lp = pd.read_csv( - f"./results/non_kd/{model_type}/{cancer}/eta_tuned_l1_ratio_{score_function}_{lambda_type+n_variables}.csv" + f"results/non_kd/{model_type}/{cancer}/eta_tuned_l1_ratio_{score_function}_{lambda_type+n_variables}.csv" ).iloc[:, 1:] for i in range(125): lp_split = lp.iloc[:, i].dropna().values @@ -167,7 +165,7 @@ def main() -> int: for score_function in ["vvh"]: for model_type in ["breslow"]: lp = pd.read_csv( - f"./results/non_kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" + f"results/non_kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" ).iloc[:, 1:] for i in range(125): lp_split = lp.iloc[:, i].dropna().values 
@@ -232,7 +230,7 @@ def main() -> int: "n_variables": variables, "tuned": tuned, } - ).to_csv("./results/metrics/metrics_overall_125_full.csv", index=False) + ).to_csv("results/metrics/metrics_overall_125_full.csv", index=False) model = [] pc = [] score = [] @@ -246,12 +244,11 @@ def main() -> int: for cancer in config["datasets"]: print(f"Starting: {cancer}") - data_path = f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - df = pd.read_csv(data_path) + df = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] time = df["OS_days"].values event = df["OS"].values - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") for n_variables in [""]: if n_variables == "": n_variables_string = 0 @@ -272,7 +269,7 @@ def main() -> int: for score_function in ["linear_predictor"]: for model_type in model_list: lp = pd.read_csv( - f"./results/kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" + f"results/kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" ) for i in range(0, 125, 5): lp_split = np.concatenate( @@ -342,7 +339,7 @@ def main() -> int: for score_function in ["vvh"]: for model_type in ["breslow"]: lp = pd.read_csv( - f"./results/non_kd/{model_type}/{cancer}/eta_tuned_l1_ratio_{score_function}_{lambda_type+n_variables}.csv" + f"results/non_kd/{model_type}/{cancer}/eta_tuned_l1_ratio_{score_function}_{lambda_type+n_variables}.csv" ).iloc[:, 1:] for i in range(0, 125, 5): lp_split = np.concatenate( @@ -417,7 +414,7 @@ def main() -> int: for score_function in ["vvh"]: for model_type in ["breslow"]: lp = pd.read_csv( - f"./results/non_kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" + f"results/non_kd/{model_type}/{cancer}/eta_{score_function}_{lambda_type+n_variables}.csv" ).iloc[:, 1:] for i in range(0, 125, 5): lp_split = np.concatenate( @@ -501,10 +498,4 @@ def main() -> int: "n_variables": variables, "tuned": tuned, } - ).to_csv("./results/metrics/metrics_overall_cved.csv", index=False) - - return 0 - - -if __name__ == "__main__": - main() + ).to_csv("results/metrics/metrics_overall_cved.csv", index=False) diff --git a/paper/scripts/py/make_metrics_teachers.py b/workflow/scripts/py/make_metrics_teachers.py similarity index 80% rename from paper/scripts/py/make_metrics_teachers.py rename to workflow/scripts/py/make_metrics_teachers.py index b54a090..524e912 100644 --- a/paper/scripts/py/make_metrics_teachers.py +++ b/workflow/scripts/py/make_metrics_teachers.py @@ -1,23 +1,20 @@ -import json -import os +import sys -import numpy as np -import pandas as pd -from pycox.evaluation import EvalSurv -from sksurv.metrics import concordance_index_censored, concordance_index_ipcw -from sksurv.util import Surv +with open(snakemake.log[0], "w") as f: + sys.stderr = sys.stdout = f + import json + import numpy as np + import pandas as pd + from pycox.evaluation import EvalSurv + from sksurv.metrics import concordance_index_censored, concordance_index_ipcw + from sksurv.util import Surv -def main() -> int: - with open("./config.json") as f: + with open(snakemake.params["config_path"]) as f: config = json.load(f) np.random.seed(config["random_state"]) sksurv_converter = Surv() transform_survival = sksurv_converter.from_arrays - - splits_path = os.path.join(".", 
"data", "splits", "TCGA") - data_path = os.path.join(".", "data", "processed", "TCGA") - os.makedirs(splits_path, exist_ok=True) model = [] pc = [] score = [] @@ -31,15 +28,14 @@ def main() -> int: for cancer in config["datasets"]: print(f"Starting: {cancer}") - data_path = f"./data/processed/TCGA/{cancer}_data_preprocessed.csv" - df = pd.read_csv(data_path) + df = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] time = df["OS_days"].values event = df["OS"].values - test_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_test_splits.csv") - train_splits = pd.read_csv(f"./data/splits/TCGA/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") for model_type in ["breslow", "cox_nnet"]: - lp = pd.read_csv(f"./results/kd/{model_type}/{cancer}/eta_teacher.csv") + lp = pd.read_csv(f"results/kd/{model_type}/{cancer}/eta_teacher.csv") for i in range(25): lp_split = lp.iloc[:, i].dropna().values test_split = test_splits.iloc[i, :].dropna().values.astype(int) @@ -85,7 +81,7 @@ def main() -> int: for i in range(25): surv = pd.read_csv( - f"./results/kd/{model_type}/{cancer}/survival_function_teacher_{i+1}.csv" + f"results/kd/{model_type}/{cancer}/survival_function_teacher_{i+1}.csv" ).T surv.index = surv.index.astype(float) test_split = test_splits.iloc[i, :].dropna().values.astype(int) @@ -125,9 +121,4 @@ def main() -> int: "n_variables": variables, "tuned": tuned, } - ).to_csv("./results/metrics/metrics_overall_teachers.csv", index=False) - return 0 - - -if __name__ == "__main__": - main() + ).to_csv("results/metrics/metrics_overall_teachers.csv", index=False) diff --git a/workflow/scripts/py/make_splits.py b/workflow/scripts/py/make_splits.py new file mode 100644 index 0000000..a9424a0 --- /dev/null +++ b/workflow/scripts/py/make_splits.py @@ -0,0 +1,50 @@ +import sys + + +def main( + data_path, + random_seed, + n_outer_repetitions, + n_outer_splits, + train_splits_output_path, + test_splits_output_path, +) -> int: + + import numpy as np + import pandas as pd + from sklearn.model_selection import RepeatedStratifiedKFold + + np.random.seed(random_seed) + data = pd.read_csv( + data_path, + low_memory=False, + ) + + # Exact column choice doesn't matter + # as this is only to create the splits anyway. 
+ X = data[[i for i in data.columns if i not in ["OS_days", "OS"]]] + cv = RepeatedStratifiedKFold( + n_repeats=n_outer_repetitions, n_splits=n_outer_splits, random_state=random_seed + ) + splits = [i for i in cv.split(X, data["OS"])] + pd.DataFrame([i[0] for i in splits]).to_csv( + train_splits_output_path, + index=False, + ) + pd.DataFrame([i[1] for i in splits]).to_csv( + test_splits_output_path, + index=False, + ) + return 0 + + +with open(snakemake.log[0], "w") as f: + sys.stderr = sys.stdout = f + main( + data_path=snakemake.input["data_path"], + random_seed=snakemake.params["random_seed"], + n_outer_repetitions=snakemake.params["n_outer_repetitions"], + n_outer_splits=snakemake.params["n_outer_splits"], + train_splits_output_path=snakemake.output["train_splits_output_path"], + test_splits_output_path=snakemake.output["test_splits_output_path"], + ) diff --git a/workflow/scripts/py/make_table_S1.py b/workflow/scripts/py/make_table_S1.py new file mode 100644 index 0000000..6487f1c --- /dev/null +++ b/workflow/scripts/py/make_table_S1.py @@ -0,0 +1,73 @@ +import sys + +with open(snakemake.log[0], "w") as f: + sys.stderr = sys.stdout = f + + import json + + import numpy as np + import pandas as pd + + with open(snakemake.params["config_path"]) as f: + config = json.load(f) + + np.random.seed(config["random_state"]) + + cancer_type = config["datasets"] + tissue = [ + "Bladder", + "Breast", + "Head and neck", + "Kidney", + "Brain", + "Liver", + "Lung", + "Lung", + "Ovaries", + "Stomach", + ] + full_name = [ + "Bladder Urothelial Carcinoma", + "Breast invasive carcinoma", + "Head and neck squamous cell carcinoma", + "Kidney renal clear cell carcinoma", + "Brain lower grade glioma", + "Liver hepatocellular carcinoma", + "Lung adenocarcinoma", + "Lung squamous cell carcinoma", + "Ovarian serous cystadenocarcinoma", + "Stomach adenocarcinoma", + ] + p = [] + n = [] + event_ratio = [] + min_event_time = [] + max_event_time = [] + median_event_time = [] + + for cancer in config["datasets"]: + print(f"Starting: {cancer}") + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + data = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] + X_ = data.iloc[:, 3:] + p.append(X_.shape[1]) + n.append(X_.shape[0]) + event_ratio.append(np.mean(data["OS"].values)) + min_event_time.append(np.min(data["OS_days"].values)) + max_event_time.append(np.max(data["OS_days"].values)) + median_event_time.append(np.median(data["OS_days"].values)) + + pd.DataFrame( + { + "type": cancer_type, + "tissue": tissue, + "full_name": full_name, + "p": p, + "n": n, + "event_ratio": event_ratio, + "min_event_time": min_event_time, + "max_event_time": max_event_time, + "median_event_time": median_event_time, + } + ).to_csv("results/tables/table_S1.csv", index=False) diff --git a/workflow/scripts/py/run_path_sparsesurv.py b/workflow/scripts/py/run_path_sparsesurv.py new file mode 100644 index 0000000..f765b7c --- /dev/null +++ b/workflow/scripts/py/run_path_sparsesurv.py @@ -0,0 +1,255 @@ +import sys + +with open(snakemake.log[0], "w") as f: + sys.stderr = sys.stdout = f + + import json + + import celer + import numpy as np + import pandas as pd + import torch + from sklearn.decomposition import PCA + from sklearn.feature_selection import VarianceThreshold + from sklearn.metrics import make_scorer + from sklearn.model_selection import GridSearchCV, RandomizedSearchCV + from sklearn.pipeline import make_pipeline + from 
sklearn.preprocessing import StandardScaler + from skorch.callbacks import EarlyStopping + from sksurv.linear_model import CoxPHSurvivalAnalysis + + from sparsesurv.cv import KDPHElasticNetCV + from sparsesurv.loss import breslow_negative_likelihood, efron_negative_likelihood + from sparsesurv.neuralsurv.python.model.model import SKORCH_MODULE_FACTORY + from sparsesurv.neuralsurv.python.model.skorch_infra import FixSeed + from sparsesurv.neuralsurv.python.utils.factories import ( + CRITERION_FACTORY, + SKORCH_NET_FACTORY, + ) + from sparsesurv.neuralsurv.python.utils.misc_utils import ( + StratifiedSkorchSurvivalSplit, + StratifiedSurvivalKFold, + ) + from sparsesurv.utils import inverse_transform_survival, transform_survival + + with open(snakemake.params["config_path"]) as f: + config = json.load(f) + + def breslow_score_wrapper(y_true, y_pred): + time, event = inverse_transform_survival(y_true) + return np.negative( + breslow_negative_likelihood( + linear_predictor=np.squeeze(y_pred), time=time, event=event + ) + ) + + def efron_score_wrapper(y_true, y_pred): + time, event = inverse_transform_survival(y_true) + return np.negative( + efron_negative_likelihood(linear_predictor=y_pred, time=time, event=event) + ) + + SCORE_FACTORY = {"breslow": breslow_score_wrapper, "efron": efron_score_wrapper} + + np.random.seed(config["random_state"]) + g = np.random.default_rng(config.get("random_state")) + model_pipe = make_pipeline( + VarianceThreshold(), + StandardScaler(), + ) + en = celer.ElasticNet( + l1_ratio=config["l1_ratio"], + fit_intercept=False, + ) + + for tie_correction in ["breslow"]: + pc_pipe = GridSearchCV( + estimator=make_pipeline( + VarianceThreshold(), + StandardScaler(), + PCA( + n_components=config[f"pc_n_components"], + random_state=config["random_state"], + ), + CoxPHSurvivalAnalysis(ties=tie_correction), + ), + param_grid={"pca__n_components": config["pc_n_components_tuned"]}, + n_jobs=config["n_jobs"], + scoring=make_scorer(SCORE_FACTORY[tie_correction]), + cv=StratifiedSurvivalKFold( + n_splits=config["n_inner_cv"], + shuffle=config["shuffle_cv"], + random_state=config["random_state"], + ), + ) + + for cancer in config["datasets"]: + sparsity = {} + + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + data = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] + X_ = data.iloc[:, 3:] + y_ = transform_survival( + time=data["OS_days"].values, event=data["OS"].values + ) + for split in range(25): + train_ix = train_splits.iloc[split, :].dropna().to_numpy().astype(int) + test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) + X_train = X_.iloc[train_ix, :].copy().reset_index(drop=True).to_numpy() + y_train = y_[train_ix].copy() + y_test = y_[test_ix].copy() + X_test = X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy() + + pc_pipe.fit(X_train, y_train) + path_results = en.path( + X=model_pipe.fit_transform(X_train), + y=pc_pipe.predict(X_train), + l1_ratio=config["l1_ratio"], + eps=config["eps"], + n_alphas=config["n_alphas"], + alphas=None, + ) + + for z in range(config["n_alphas"]): + path_coef = path_results[1][:, z] + if z == 0: + sparsity[split] = [] + sparsity[split].append(np.sum(path_coef != 0.0)) + helper = KDPHElasticNetCV( + tie_correction="efron", + seed=np.random.RandomState(config["random_state"]), + ) + helper.coef_ = path_coef + ix_sort = np.argsort(y_train["time"]) + helper.train_time_ = y_train["time"][ix_sort] + 
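# the CV helper is populated by hand: coefficients come from the celer path + # fit above, and sorted training data are attached so that + # predict_survival_function can estimate a baseline hazard. +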
helper.train_event_ = y_train["event"][ix_sort] + helper.train_eta_ = helper.predict(model_pipe.transform(X_train))[ + ix_sort + ] + surv = helper.predict_survival_function( + model_pipe.transform(X_test), np.unique(y_test["time"]) + ) + surv.to_csv( + f"results/kd/{tie_correction}/{cancer}/path/survival_function_{z+1}_alpha_{split+1}.csv", + index=False, + ) + + pd.DataFrame(sparsity).to_csv( + f"results/kd/{tie_correction}/{cancer}/path/sparsity.csv", + index=False, + ) + + for cancer in config["datasets"]: + sparsity = {} + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + data = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] + X_ = data.iloc[:, 3:] + y_ = transform_survival(time=data["OS_days"].values, event=data["OS"].values) + pc_pipe = RandomizedSearchCV( + estimator=make_pipeline( + StandardScaler(), + SKORCH_NET_FACTORY["cox"]( + module=SKORCH_MODULE_FACTORY["cox"], + criterion=CRITERION_FACTORY["cox"], + module__fusion_method="early", + module__blocks=[[i for i in range(X_.shape[1])]], + iterator_train__shuffle=True, + optimizer=torch.optim.AdamW, + max_epochs=config["max_epochs"], + verbose=False, + train_split=StratifiedSkorchSurvivalSplit( + config["validation_set_neural"], + stratified=config["stratify_cv"], + random_state=config.get("random_state"), + ), + callbacks=[ + ( + "es", + EarlyStopping( + monitor="valid_loss", + patience=config["early_stopping_patience"], + load_best=True, + ), + ), + ("seed", FixSeed(generator=g)), + ], + module__activation=torch.nn.ReLU, + ), + ), + param_distributions={ + "coxphneuralnet__lr": config["tune_lr"], + "coxphneuralnet__optimizer__weight_decay": config["tune_weight_decay"], + "coxphneuralnet__module__modality_hidden_layer_size": config[ + "tune_modality_hidden_layer_size" + ], + "coxphneuralnet__module__modality_hidden_layers": config[ + "tune_modality_hidden_layers" + ], + "coxphneuralnet__module__p_dropout": config["tune_p_dropout"], + "coxphneuralnet__batch_size": config["tune_batch_size"], + }, + n_jobs=config["n_jobs"], + random_state=config["random_state"], + scoring=make_scorer(breslow_score_wrapper), + cv=StratifiedSurvivalKFold( + n_splits=config["n_inner_cv"], + shuffle=config["shuffle_cv"], + random_state=config["random_state"], + ), + error_score=config["error_score"], + verbose=False, + n_iter=config["random_search_n_iter"], + ) + for split in range(25): + train_ix = train_splits.iloc[split, :].dropna().to_numpy().astype(int) + test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) + X_train = ( + X_.iloc[train_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) + ) + y_train = y_[train_ix].copy() + y_test = y_[test_ix].copy() + X_test = ( + X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) + ) + + pc_pipe.fit(X_train, y_train) + path_results = en.path( + X=model_pipe.fit_transform(X_train), + y=np.squeeze(pc_pipe.predict(X_train)), + l1_ratio=config["l1_ratio"], + eps=config["eps"], + n_alphas=config["n_alphas"], + alphas=None, + ) + + for z in range(config["n_alphas"]): + path_coef = path_results[1][:, z] + if z == 0: + sparsity[split] = [] + sparsity[split].append(np.sum(path_coef != 0.0)) + helper = KDPHElasticNetCV( + tie_correction="breslow", + seed=np.random.RandomState(config["random_state"]), + ) + helper.coef_ = path_coef + ix_sort = np.argsort(y_train["time"]) + helper.train_time_ = y_train["time"][ix_sort] + helper.train_event_ = 
y_train["event"][ix_sort] + helper.train_eta_ = helper.predict(model_pipe.transform(X_train))[ + ix_sort + ] + + surv = helper.predict_survival_function( + model_pipe.transform(X_test), np.unique(y_test["time"]) + ) + surv.to_csv( + f"results/kd/cox_nnet/{cancer}/path/survival_function_{z+1}_alpha_{split+1}.csv", + index=False, + ) + + pd.DataFrame(sparsity).to_csv( + f"results/kd/cox_nnet/{cancer}/path/sparsity.csv", + index=False, + ) diff --git a/workflow/scripts/py/run_sparsesurv.py b/workflow/scripts/py/run_sparsesurv.py new file mode 100644 index 0000000..d128024 --- /dev/null +++ b/workflow/scripts/py/run_sparsesurv.py @@ -0,0 +1,351 @@ +import sys + +with open(snakemake.log[0], "w") as f: + sys.stderr = sys.stdout = f + + import json + + import numpy as np + import pandas as pd + import torch + from sklearn.decomposition import PCA + from sklearn.feature_selection import VarianceThreshold + from sklearn.metrics import make_scorer + from sklearn.model_selection import GridSearchCV, RandomizedSearchCV + from sklearn.pipeline import make_pipeline + from sklearn.preprocessing import StandardScaler + from skorch.callbacks import EarlyStopping + from sksurv.linear_model import CoxPHSurvivalAnalysis + + from sparsesurv._base import KDSurv + from sparsesurv.cv import KDPHElasticNetCV + from sparsesurv.loss import breslow_negative_likelihood, efron_negative_likelihood + from sparsesurv.neuralsurv.python.model.model import SKORCH_MODULE_FACTORY + from sparsesurv.neuralsurv.python.model.skorch_infra import FixSeed + from sparsesurv.neuralsurv.python.utils.factories import ( + CRITERION_FACTORY, + SKORCH_NET_FACTORY, + ) + from sparsesurv.neuralsurv.python.utils.misc_utils import ( + StratifiedSkorchSurvivalSplit, + StratifiedSurvivalKFold, + ) + from sparsesurv.utils import inverse_transform_survival, transform_survival + + with open(snakemake.params["config_path"]) as f: + config = json.load(f) + + def breslow_score_wrapper(y_true, y_pred): + time, event = inverse_transform_survival(y_true) + return np.negative( + breslow_negative_likelihood( + linear_predictor=np.squeeze(y_pred), time=time, event=event + ) + ) + + def efron_score_wrapper(y_true, y_pred): + time, event = inverse_transform_survival(y_true) + return np.negative( + efron_negative_likelihood( + linear_predictor=np.squeeze(y_pred), time=time, event=event + ) + ) + + np.random.seed(config["random_state"]) + g = np.random.default_rng(config.get("random_state")) + + for tune_l1_ratio in [False]: + for tie_correction in ["breslow"]: + for score_type in ["min", "pcvl"]: + for score in ["linear_predictor"]: + results = {} + failures = {} + sparsity = {} + pipe = KDSurv( + teacher=GridSearchCV( + estimator=make_pipeline( + VarianceThreshold(), + StandardScaler(), + PCA(n_components=config["pc_n_components_tuned"], random_state=config["random_state"]), + CoxPHSurvivalAnalysis(ties=tie_correction), + ), + param_grid={ + "pca__n_components": config["pc_n_components_tuned"] + }, + n_jobs=config["n_jobs"], + scoring=make_scorer(efron_score_wrapper), + cv=StratifiedSurvivalKFold( + n_splits=config["n_inner_cv"], + shuffle=config["shuffle_cv"], + random_state=config["random_state"], + ), + ), + student=make_pipeline( + VarianceThreshold(), + StandardScaler(), + KDPHElasticNetCV( + tie_correction=tie_correction, + l1_ratio=config[ + f"l1_ratio{'_tuned' if tune_l1_ratio else ''}" + ], + eps=config["eps"], + n_alphas=config["n_alphas"], + cv=config["n_inner_cv"], + stratify_cv=config["stratify_cv"], + 
seed=np.random.RandomState(config["random_state"]), + shuffle_cv=config["shuffle_cv"], + n_jobs=config["n_jobs"], + cv_score_method=score, + alpha_type=score_type, + ), + ), + ) + + for cancer in config["datasets"]: + + train_splits = pd.read_csv( + f"results/make_splits/{cancer}_train_splits.csv" + ) + test_splits = pd.read_csv( + f"results/make_splits/{cancer}_test_splits.csv" + ) + data = pd.read_csv( + f"results/preprocess_data/{cancer}.csv" + ).iloc[:, 1:] + X_ = data.iloc[:, 3:] + y_ = transform_survival( + time=data["OS_days"].values, event=data["OS"].values + ) + for split in range(125 if score_type == "min" else 25): + train_ix = ( + train_splits.iloc[split, :] + .dropna() + .to_numpy() + .astype(int) + ) + test_ix = ( + test_splits.iloc[split, :] + .dropna() + .to_numpy() + .astype(int) + ) + X_train = ( + X_.iloc[train_ix, :] + .copy() + .reset_index(drop=True) + .to_numpy() + ) + y_train = y_[train_ix].copy() + y_test = y_[test_ix].copy() + X_test = ( + X_.iloc[test_ix, :] + .copy() + .reset_index(drop=True) + .to_numpy() + ) + if split == 0: + results[cancer] = {} + sparsity[cancer] = {} + failures[cancer] = [0] + try: + pipe.fit(X_train, y_train) + sparsity[cancer][split] = np.sum( + pipe.student[-1].coef_ != 0 + ) + results[cancer][split] = pipe.predict(X_test) + surv = pipe.predict_survival_function( + X_test, np.unique(y_test["time"]) + ) + surv.to_csv( + f"results/kd/{tie_correction}/{cancer}/survival_function{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}_{split+1}.csv", + index=False, + ) + except ValueError as e: + failures[cancer][0] += 1 + results[cancer][split] = np.zeros(test_ix.shape[0]) + sparsity[cancer][split] = 0 + + pd.concat( + [ + pd.DataFrame(results[cancer][i]) + for i in range(125 if score_type == "min" else 25) + ], + axis=1, + ).to_csv( + f"results/kd/{tie_correction}/{cancer}/eta{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", + index=False, + ) + + pd.DataFrame(sparsity).to_csv( + f"results/kd/{tie_correction}/sparsity{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", + index=False, + ) + pd.DataFrame(failures).to_csv( + f"results/kd/{tie_correction}/failures{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", + index=False, + ) + + for tune_l1_ratio in [False]: + for score_type in ["min"]: + results = {} + failures = {} + sparsity = {} + for cancer in config["datasets"]: + for score in ["linear_predictor"]: + + train_splits = pd.read_csv( + f"results/make_splits/{cancer}_train_splits.csv" + ) + test_splits = pd.read_csv( + f"results/make_splits/{cancer}_test_splits.csv" + ) + data = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[ + :, 1: + ] + X_ = data.iloc[:, 3:] + y_ = transform_survival( + time=data["OS_days"].values, event=data["OS"].values + ) + pipe = KDSurv( + teacher=RandomizedSearchCV( + estimator=make_pipeline( + StandardScaler(), + SKORCH_NET_FACTORY["cox"]( + module=SKORCH_MODULE_FACTORY["cox"], + criterion=CRITERION_FACTORY["cox"], + module__fusion_method="early", + module__blocks=[[i for i in range(X_.shape[1])]], + iterator_train__shuffle=True, + optimizer=torch.optim.AdamW, + max_epochs=config["max_epochs"], + verbose=False, + train_split=StratifiedSkorchSurvivalSplit( + config["validation_set_neural"], + stratified=config["stratify_cv"], + random_state=config.get("random_state"), + ), + callbacks=[ + ( + "es", + EarlyStopping( + monitor="valid_loss", + patience=config[ + "early_stopping_patience" + ], + load_best=True, + ), + ), + ("seed", FixSeed(generator=g)), + 
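# EarlyStopping(load_best=True) restores the best validation-loss weights; + # FixSeed fixes the RNG state from the shared generator g. +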
], + module__activation=torch.nn.ReLU, + ), + ), + param_distributions=[ + { + "coxphneuralnet__lr": config["tune_lr"], + "coxphneuralnet__optimizer__weight_decay": config[ + "tune_weight_decay" + ], + "coxphneuralnet__module__modality_hidden_layer_size": config[ + "tune_modality_hidden_layer_size" + ], + "coxphneuralnet__module__modality_hidden_layers": config[ + "tune_modality_hidden_layers" + ], + "coxphneuralnet__module__p_dropout": config[ + "tune_p_dropout" + ], + "coxphneuralnet__batch_size": config[ + "tune_batch_size" + ], + } + ], + n_jobs=config["n_jobs"], + scoring=make_scorer(breslow_score_wrapper), + cv=StratifiedSurvivalKFold( + n_splits=config["n_inner_cv"], + shuffle=config["shuffle_cv"], + random_state=config["random_state"], + ), + error_score=config["error_score"], + verbose=False, + n_iter=config["random_search_n_iter"], + random_state=config["random_state"], + ), + student=make_pipeline( + VarianceThreshold(), + StandardScaler(), + KDPHElasticNetCV( + tie_correction="breslow", + l1_ratio=config[ + f"l1_ratio{'_tuned' if tune_l1_ratio else ''}" + ], + eps=config["eps"], + n_alphas=config["n_alphas"], + cv=config["n_inner_cv"], + stratify_cv=config["stratify_cv"], + seed=np.random.RandomState(config["random_state"]), + shuffle_cv=config["shuffle_cv"], + cv_score_method=score, + n_jobs=config["n_jobs"], + alpha_type=score_type, + ), + ), + ) + for split in range(25): + train_ix = ( + train_splits.iloc[split, :].dropna().to_numpy().astype(int) + ) + test_ix = ( + test_splits.iloc[split, :].dropna().to_numpy().astype(int) + ) + X_train = ( + X_.iloc[train_ix, :] + .copy() + .reset_index(drop=True) + .to_numpy(np.float32) + ) + y_train = y_[train_ix].copy() + y_test = y_[test_ix].copy() + X_test = ( + X_.iloc[test_ix, :] + .copy() + .reset_index(drop=True) + .to_numpy(np.float32) + ) + if split == 0: + results[cancer] = {} + sparsity[cancer] = {} + failures[cancer] = [0] + try: + pipe.fit(X_train, y_train) + sparsity[cancer][split] = np.sum( + pipe.student[-1].coef_ != 0 + ) + results[cancer][split] = pipe.predict(X_test) + + surv = pipe.predict_survival_function( + X_test.astype(np.float32), np.unique(y_test["time"]) + ) + + surv.to_csv( + f"results/kd/cox_nnet/{cancer}/survival_function{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}_{split+1}.csv", + index=False, + ) + except ValueError as e: + raise e + pd.concat( + [pd.DataFrame(results[cancer][i]) for i in range(25)], + axis=1, + ).to_csv( + f"results/kd/cox_nnet/{cancer}/eta{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", + index=False, + ) + pd.DataFrame(sparsity).to_csv( + f"results/kd/cox_nnet/sparsity{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", + index=False, + ) + pd.DataFrame(failures).to_csv( + f"results/kd/cox_nnet/failures{'_tuned_l1' if tune_l1_ratio else ''}_{score}_{score_type}.csv", + index=False, + ) diff --git a/workflow/scripts/py/run_teachers.py b/workflow/scripts/py/run_teachers.py new file mode 100644 index 0000000..10aeb15 --- /dev/null +++ b/workflow/scripts/py/run_teachers.py @@ -0,0 +1,275 @@ +import sys + +with open(snakemake.log[0], "w") as f: + sys.stderr = sys.stdout = f + import json + + import numpy as np + import pandas as pd + import torch + from sklearn.decomposition import PCA + from sklearn.feature_selection import VarianceThreshold + from sklearn.metrics import make_scorer + from sklearn.model_selection import GridSearchCV, RandomizedSearchCV + from sklearn.pipeline import make_pipeline + from sklearn.preprocessing import 
StandardScaler + from skorch.callbacks import EarlyStopping + from sksurv.linear_model import CoxPHSurvivalAnalysis + from sksurv.linear_model.coxph import BreslowEstimator + from sparsesurv.loss import breslow_negative_likelihood, efron_negative_likelihood + from sparsesurv.neuralsurv.python.model.model import SKORCH_MODULE_FACTORY + from sparsesurv.neuralsurv.python.model.skorch_infra import FixSeed + from sparsesurv.neuralsurv.python.utils.factories import ( + CRITERION_FACTORY, + SKORCH_NET_FACTORY, + ) + from sparsesurv.neuralsurv.python.utils.misc_utils import ( + StratifiedSkorchSurvivalSplit, + StratifiedSurvivalKFold, + ) + from sparsesurv.utils import inverse_transform_survival, transform_survival + + def breslow_score_wrapper(y_true, y_pred): + time, event = inverse_transform_survival(y_true) + return np.negative( + breslow_negative_likelihood( + linear_predictor=np.squeeze(y_pred), time=time, event=event + ) + ) + + def efron_score_wrapper(y_true, y_pred): + time, event = inverse_transform_survival(y_true) + return np.negative( + efron_negative_likelihood(linear_predictor=y_pred, time=time, event=event) + ) + + with open(snakemake.params["config_path"]) as f: + config = json.load(f) + + np.random.seed(config["random_state"]) + g = np.random.default_rng(config.get("random_state")) + + results = {} + for tie_correction in ["breslow"]: + teacher = GridSearchCV( + estimator=make_pipeline( + VarianceThreshold(), + StandardScaler(), + PCA( + n_components=config["pc_n_components"], + random_state=config["random_state"], + ), + CoxPHSurvivalAnalysis(ties=tie_correction), + ), + param_grid={"pca__n_components": config["pc_n_components_tuned"]}, + n_jobs=config["n_jobs"], + scoring=make_scorer(efron_score_wrapper), + cv=StratifiedSurvivalKFold( + n_splits=config["n_inner_cv"], + shuffle=config["shuffle_cv"], + random_state=config["random_state"], + ), + ) + + for cancer in config["datasets"]: + results[cancer] = {} + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + data = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] + X_ = data.iloc[:, 3:] + y_ = transform_survival( + time=data["OS_days"].values, event=data["OS"].values + ) + for split in range(25): + train_ix = train_splits.iloc[split, :].dropna().to_numpy().astype(int) + test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) + X_train = X_.iloc[train_ix, :].copy().reset_index(drop=True).to_numpy() + y_train = y_[train_ix].copy() + y_test = y_[test_ix].copy() + X_test = X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy() + + teacher.fit(X_train, y_train) + results[cancer][split] = teacher.predict(X_test) + ( + cumulative_baseline_hazards_times, + cumulative_baseline_hazards, + ) = ( + teacher.best_estimator_[3].cum_baseline_hazard_.x, + teacher.best_estimator_[3].cum_baseline_hazard_.y, + ) + cumulative_baseline_hazards = np.concatenate( + [np.array([0.0]), cumulative_baseline_hazards] + ) + cumulative_baseline_hazards_times: np.array = np.concatenate( + [np.array([0.0]), cumulative_baseline_hazards_times] + ) + cumulative_baseline_hazards: np.array = np.tile( + A=cumulative_baseline_hazards[ + np.digitize( + x=np.unique(y_test["time"]), + bins=cumulative_baseline_hazards_times, + right=False, + ) + - 1 + ], + reps=X_test.shape[0], + ).reshape((X_test.shape[0], np.unique(y_test["time"]).shape[0])) + log_hazards: np.array = ( + np.tile( + A=teacher.predict(X_test), + 
reps=np.unique(y_test["time"]).shape[0], + ) + .reshape((np.unique(y_test["time"]).shape[0], X_test.shape[0])) + .T + ) + surv: pd.DataFrame = np.exp( + -pd.DataFrame( + cumulative_baseline_hazards * np.exp(log_hazards), + columns=np.unique(y_test["time"]), + ) + ) + surv.to_csv( + f"results/kd/{tie_correction}/{cancer}/survival_function_teacher_{split+1}.csv", + index=False, + ) + pd.concat( + [pd.DataFrame(results[cancer][i]) for i in range(25)], + axis=1, + ).to_csv( + f"results/kd/{tie_correction}/{cancer}/eta_teacher.csv", + index=False, + ) + + results = {} + for cancer in config["datasets"]: + results[cancer] = {} + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + data = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] + X_ = data.iloc[:, 3:] + y_ = transform_survival(time=data["OS_days"].values, event=data["OS"].values) + teacher_cox_nnet = RandomizedSearchCV( + estimator=make_pipeline( + StandardScaler(), + SKORCH_NET_FACTORY["cox"]( + module=SKORCH_MODULE_FACTORY["cox"], + criterion=CRITERION_FACTORY["cox"], + module__fusion_method="early", + module__blocks=[[i for i in range(X_.shape[1])]], + iterator_train__shuffle=True, + optimizer=torch.optim.AdamW, + max_epochs=config["max_epochs"], + verbose=False, + train_split=StratifiedSkorchSurvivalSplit( + config["validation_set_neural"], + stratified=config["stratify_cv"], + random_state=config.get("random_state"), + ), + callbacks=[ + ( + "es", + EarlyStopping( + monitor="valid_loss", + patience=config["early_stopping_patience"], + load_best=True, + ), + ), + ("seed", FixSeed(generator=g)), + ], + module__activation=torch.nn.ReLU, + ), + ), + param_distributions={ + "coxphneuralnet__lr": config["tune_lr"], + "coxphneuralnet__optimizer__weight_decay": config["tune_weight_decay"], + "coxphneuralnet__module__modality_hidden_layer_size": config[ + "tune_modality_hidden_layer_size" + ], + "coxphneuralnet__module__modality_hidden_layers": config[ + "tune_modality_hidden_layers" + ], + "coxphneuralnet__module__p_dropout": config["tune_p_dropout"], + "coxphneuralnet__batch_size": config["tune_batch_size"], + }, + n_jobs=config["n_jobs"], + random_state=config["random_state"], + scoring=make_scorer(breslow_score_wrapper), + cv=StratifiedSurvivalKFold( + n_splits=config["n_inner_cv"], + shuffle=config["shuffle_cv"], + random_state=config["random_state"], + ), + error_score=config["error_score"], + verbose=False, + n_iter=config["random_search_n_iter"], + ) + for split in range(25): + train_ix = train_splits.iloc[split, :].dropna().to_numpy().astype(int) + test_ix = test_splits.iloc[split, :].dropna().to_numpy().astype(int) + X_train = ( + X_.iloc[train_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) + ) + y_train = y_[train_ix].copy() + y_test = y_[test_ix].copy() + X_test = ( + X_.iloc[test_ix, :].copy().reset_index(drop=True).to_numpy(np.float32) + ) + + teacher_cox_nnet.fit(X_train, y_train) + results[cancer][split] = teacher_cox_nnet.predict(X_test) + breslow = BreslowEstimator() + breslow.fit( + linear_predictor=teacher_cox_nnet.predict(X_train), + time=data["OS_days"][train_ix].values, + event=data["OS"][train_ix].values, + ) + ( + cumulative_baseline_hazards_times, + cumulative_baseline_hazards, + ) = ( + breslow.cum_baseline_hazard_.x, + breslow.cum_baseline_hazard_.y, + ) + cumulative_baseline_hazards = np.concatenate( + [np.array([0.0]), cumulative_baseline_hazards] + ) + 
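# prepending time zero means test times before the first training event + # digitize into a bucket with zero cumulative hazard. +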
cumulative_baseline_hazards_times: np.array = np.concatenate( + [np.array([0.0]), cumulative_baseline_hazards_times] + ) + cumulative_baseline_hazards: np.array = np.tile( + A=cumulative_baseline_hazards[ + np.digitize( + x=np.unique(y_test["time"]), + bins=cumulative_baseline_hazards_times, + right=False, + ) + - 1 + ], + reps=X_test.shape[0], + ).reshape((X_test.shape[0], np.unique(y_test["time"]).shape[0])) + log_hazards: np.array = ( + np.tile( + A=teacher_cox_nnet.predict(X_test).squeeze(), + reps=np.unique(y_test["time"]).shape[0], + ) + .reshape((np.unique(y_test["time"]).shape[0], X_test.shape[0])) + .T + ) + surv: pd.DataFrame = np.exp( + -pd.DataFrame( + cumulative_baseline_hazards * np.exp(log_hazards), + columns=np.unique(y_test["time"]), + ) + ) + surv.to_csv( + f"results/kd/cox_nnet/{cancer}/survival_function_teacher_{split+1}.csv", + index=False, + ) + + pd.concat( + [pd.DataFrame(results[cancer][i]) for i in range(25)], + axis=1, + ).to_csv( + f"results/kd/cox_nnet/{cancer}/eta_teacher.csv", + index=False, + ) diff --git a/workflow/scripts/py/time_sparsesurv.py b/workflow/scripts/py/time_sparsesurv.py new file mode 100644 index 0000000..bd7c048 --- /dev/null +++ b/workflow/scripts/py/time_sparsesurv.py @@ -0,0 +1,245 @@ +import sys + +with open(snakemake.log[0], "w") as f: + sys.stderr = sys.stdout = f + import json + from timeit import default_timer as timer + + import numpy as np + import pandas as pd + import torch + from sklearn.decomposition import PCA + from sklearn.feature_selection import VarianceThreshold + from sklearn.metrics import make_scorer + from sklearn.model_selection import GridSearchCV, RandomizedSearchCV + from sklearn.pipeline import make_pipeline + from sklearn.preprocessing import StandardScaler + from skorch.callbacks import EarlyStopping + from sksurv.linear_model import CoxPHSurvivalAnalysis + + from sparsesurv._base import KDSurv + from sparsesurv.cv import KDPHElasticNetCV + from sparsesurv.loss import breslow_negative_likelihood, efron_negative_likelihood + from sparsesurv.neuralsurv.python.model.model import SKORCH_MODULE_FACTORY + from sparsesurv.neuralsurv.python.model.skorch_infra import FixSeed + from sparsesurv.neuralsurv.python.utils.factories import ( + CRITERION_FACTORY, + SKORCH_NET_FACTORY, + ) + from sparsesurv.neuralsurv.python.utils.misc_utils import ( + StratifiedSkorchSurvivalSplit, + StratifiedSurvivalKFold, + ) + from sparsesurv.utils import inverse_transform_survival, transform_survival + + with open(snakemake.params["config_path"]) as f: + config = json.load(f) + + def breslow_score_wrapper(y_true, y_pred): + time, event = inverse_transform_survival(y_true) + return np.negative( + breslow_negative_likelihood( + linear_predictor=np.squeeze(y_pred.astype(np.float64)), + time=time, + event=event, + ) + ) + + def efron_score_wrapper(y_true, y_pred): + time, event = inverse_transform_survival(y_true) + return np.negative( + efron_negative_likelihood(linear_predictor=y_pred, time=time, event=event) + ) + + SCORE_FACTORY = {"breslow": breslow_score_wrapper, "efron": efron_score_wrapper} + + g = np.random.default_rng(config.get("random_state")) + np.random.seed(config["random_state"]) + + for tune_teacher in [True]: + for tune_l1_ratio in [False]: + for tie_correction in ["breslow"]: + timing = {} + for cancer in config["datasets"]: + timing[cancer] = [] + print(f"Starting: {cancer}") + train_splits = pd.read_csv( + f"results/make_splits/{cancer}_train_splits.csv" + ) + test_splits = pd.read_csv( + 
f"results/make_splits/{cancer}_test_splits.csv" + ) + data = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[ + :, 1: + ] + X_ = data.iloc[:, 3:].to_numpy() + y_ = transform_survival( + time=data["OS_days"].values, event=data["OS"].values + ) + for rep in range(config["timing_reps"]): + pipe = KDSurv( + teacher=GridSearchCV( + estimator=make_pipeline( + StandardScaler(), + PCA( + n_components=config["pc_n_components"], + random_state=config["random_state"], + ), + CoxPHSurvivalAnalysis(ties=tie_correction), + ), + param_grid={ + "pca__n_components": config[ + f"pc_n_components{'_tuned' if tune_teacher else ''}" + ] + }, + n_jobs=1, + verbose=0, + scoring=make_scorer(SCORE_FACTORY[tie_correction]), + cv=StratifiedSurvivalKFold( + n_splits=config["n_inner_cv"], + shuffle=config["shuffle_cv"], + random_state=config["random_state"], + ), + ), + student=make_pipeline( + VarianceThreshold(), + StandardScaler(), + KDPHElasticNetCV( + tie_correction=tie_correction, + l1_ratio=config[ + f"l1_ratio{'_tuned' if tune_l1_ratio else ''}" + ], + eps=config["eps"], + n_alphas=config["n_alphas"], + cv=config["n_inner_cv"], + stratify_cv=config["stratify_cv"], + seed=np.random.RandomState(config["random_state"]), + shuffle_cv=config["shuffle_cv"], + cv_score_method="linear_predictor", + n_jobs=1, + ), + ), + ) + start = timer() + pipe.fit(X_, y_) + end = timer() + timing[cancer].append(end - start) + if tune_l1_ratio: + pd.DataFrame(timing).to_csv( + f"results/kd/{tie_correction}/timing_tuned_l1_ratio{'_tuned_teacher' if tune_teacher else '' }.csv", + index=False, + ) + else: + pd.DataFrame(timing).to_csv( + f"results/kd/{tie_correction}/timing{'_tuned_teacher' if tune_teacher else '' }.csv", + index=False, + ) + + for tune_l1_ratio in [False]: + timing = {} + for cancer in config["datasets"]: + timing[cancer] = [] + print(f"Starting: {cancer}") + train_splits = pd.read_csv(f"results/make_splits/{cancer}_train_splits.csv") + test_splits = pd.read_csv(f"results/make_splits/{cancer}_test_splits.csv") + data = pd.read_csv(f"results/preprocess_data/{cancer}.csv").iloc[:, 1:] + X_ = data.iloc[:, 3:].to_numpy().astype(np.float32) + y_ = transform_survival( + time=data["OS_days"].values, event=data["OS"].values + ) + for rep in range(config["timing_reps"]): + pipe = KDSurv( + teacher=RandomizedSearchCV( + estimator=make_pipeline( + StandardScaler(), + SKORCH_NET_FACTORY["cox"]( + module=SKORCH_MODULE_FACTORY["cox"], + criterion=CRITERION_FACTORY["cox"], + module__fusion_method="early", + module__blocks=[[i for i in range(X_.shape[1])]], + iterator_train__shuffle=True, + optimizer=torch.optim.AdamW, + max_epochs=config["max_epochs"], + verbose=False, + train_split=StratifiedSkorchSurvivalSplit( + config["validation_set_neural"], + stratified=config["stratify_cv"], + random_state=config.get("random_state"), + ), + callbacks=[ + ( + "es", + EarlyStopping( + monitor="valid_loss", + patience=config["early_stopping_patience"], + load_best=True, + ), + ), + ("seed", FixSeed(generator=g)), + ], + module__activation=torch.nn.ReLU, + ), + ), + param_distributions={ + "coxphneuralnet__lr": config["tune_lr"], + "coxphneuralnet__optimizer__weight_decay": config[ + "tune_weight_decay" + ], + "coxphneuralnet__module__modality_hidden_layer_size": config[ + "tune_modality_hidden_layer_size" + ], + "coxphneuralnet__module__modality_hidden_layers": config[ + "tune_modality_hidden_layers" + ], + "coxphneuralnet__module__p_dropout": config[ + "tune_p_dropout" + ], + "coxphneuralnet__batch_size": config["tune_batch_size"], + 
}, + n_jobs=1, + random_state=config["random_state"], + scoring=make_scorer(breslow_score_wrapper), + cv=StratifiedSurvivalKFold( + n_splits=config["n_inner_cv"], + shuffle=config["shuffle_cv"], + random_state=config["random_state"], + ), + error_score=config["error_score"], + verbose=False, + n_iter=config["random_search_n_iter"], + ), + student=make_pipeline( + VarianceThreshold(), + StandardScaler(), + KDPHElasticNetCV( + tie_correction="breslow", + l1_ratio=config[ + f"l1_ratio{'_tuned' if tune_l1_ratio else ''}" + ], + eps=config["eps"], + n_alphas=config["n_alphas"], + cv=config["n_inner_cv"], + stratify_cv=config["stratify_cv"], + seed=np.random.RandomState(config["random_state"]), + shuffle_cv=config["shuffle_cv"], + cv_score_method="linear_predictor", + n_jobs=1, + ), + ), + ) + start = timer() + pipe.fit(X_, y_) + end = timer() + timing[cancer].append(end - start) + + if tune_l1_ratio: + pd.DataFrame(timing).to_csv( + f"results/kd/cox_nnet/timing_tuned_l1_ratio.csv", + index=False, + ) + + else: + pd.DataFrame(timing).to_csv( + f"results/kd/cox_nnet/timing.csv", + index=False, + ) diff --git a/workflow/scripts/r/make_table_S2.R b/workflow/scripts/r/make_table_S2.R new file mode 100644 index 0000000..1fb024e --- /dev/null +++ b/workflow/scripts/r/make_table_S2.R @@ -0,0 +1,74 @@ +log <- file(snakemake@log[[1]], open = "wt") +sink(log) + +suppressPackageStartupMessages({ + library(ggplot2) + library(dplyr) + library(vroom) + library(cowplot) + library(ggsignif) + library(readr) + library(tidyr) + options(warn = 1) +}) + +config <- rjson::fromJSON( + file = snakemake@params[["config_path"]] +) + +sparsity <- data.frame( + sparsity = c( + c( + unlist(as.vector(vroom::vroom( + paste( + "results", "non_kd", "breslow", "sparsity_vvh_lambda.min.csv", + sep = "/" + ) + )[1:25, ])) + ), + sapply(c( + "BLCA", + "BRCA", + "HNSC", + "KIRC", + "LGG", + "LIHC", + "LUAD", + "LUSC", + "OV", + "STAD" + ), function(cancer) { + unname(unlist(vroom::vroom(paste( + "results", "non_kd", "breslow", cancer, "sparsity_tuned_l1_ratio_vvh_lambda.min.csv", + sep = "/" + ), delim = ",")[, 1]))[1:25] + }), + unlist(as.vector(vroom::vroom( + paste( + "results", "kd", "breslow", "sparsity_linear_predictor_min.csv", + sep = "/" + ) + )[1:25, ])), + unlist(as.vector(vroom::vroom( + paste( + "results", "kd", "breslow", "sparsity_linear_predictor_pcvl.csv", + sep = "/" + ) + ))), + unlist(as.vector(vroom::vroom( + paste( + "results", "kd", "cox_nnet", "sparsity_linear_predictor_min.csv", + sep = "/" + ) + ))) + ), + cancer = rep(rep(config$datasets, each = 25), 5), + model = rep(c("glmnet (Breslow)", "glmnet tuned (Breslow)", "KD Breslow (min)", "KD Breslow (pcvl)", "KD Cox-Nnet (min)"), each = 250) +) + + +sparsity %>% + group_by(model, cancer) %>% + summarise(value = paste0(round(mean(sparsity), 2), " (", round(sd(sparsity), 2), ")")) %>% + pivot_wider(names_from = cancer, values_from = value) %>% + write_csv(paste("results", "tables", "table_S2.csv", sep = "/")) diff --git a/workflow/scripts/r/make_table_S3.R b/workflow/scripts/r/make_table_S3.R new file mode 100644 index 0000000..1fe0d14 --- /dev/null +++ b/workflow/scripts/r/make_table_S3.R @@ -0,0 +1,82 @@ +log <- file(snakemake@log[[1]], open = "wt") +sink(log) + +suppressPackageStartupMessages({ + library(ggplot2) + library(dplyr) + library(vroom) + library(cowplot) + library(ggsignif) + library(readr) + library(tidyr) + options(warn = 1) +}) + +config <- rjson::fromJSON( + file = snakemake@params[["config_path"]] +) + +tmp_glmnet <- (unname(unlist( + 
vroom::vroom(paste( + "results", "non_kd", "breslow", "failures_vvh_lambda.min.csv", + sep = "/" + ), delim = ",")[1, ] +))) + +failures <- data.frame(t(data.frame( + glmnet = c(tmp_glmnet[1], diff(tmp_glmnet)), + glmnet_tuned = sapply(c( + "BLCA", + "BRCA", + "HNSC", + "KIRC", + "LGG", + "LIHC", + "LUAD", + "LUSC", + "OV", + "STAD" + ), function(cancer) { + as.numeric((vroom::vroom(paste( + "results", "non_kd", "breslow", cancer, "failures_tuned_l1_ratio_vvh_lambda.min.csv", + sep = "/" + ), delim = ",")[1, 1])) + }), + kd_cox_nnet = unlist(unname( + vroom::vroom(paste( + "results", "kd", "cox_nnet", "failures_linear_predictor_min.csv", + sep = "/" + ), delim = ",")[, 1] + )), + kd_breslow = unlist(unname( + vroom::vroom(paste( + "results", "kd", "breslow", "failures_linear_predictor_min.csv", + sep = "/" + ), delim = ",")[, 1] + )), + kd_breslow_pcvl = unlist(unname( + vroom::vroom(paste( + "results", "kd", "breslow", "failures_linear_predictor_pcvl.csv", + sep = "/" + ), delim = ",")[, 1] + )) +))) + +colnames(failures) <- c( + "BLCA", + "BRCA", + "HNSC", + "KIRC", + "LGG", + "LIHC", + "LUAD", + "LUSC", + "OV", + "STAD" +) +# Row labels must follow the column order of the transposed data.frame above: +# glmnet, glmnet_tuned, kd_cox_nnet, kd_breslow, kd_breslow_pcvl. +failures$model <- c("glmnet (Breslow)", "glmnet tuned (Breslow)", "KD Cox-Nnet (min)", "KD Breslow (min)", "KD Breslow (pcvl)") + + +failures %>% + select(model, BLCA, BRCA, HNSC, KIRC, LGG, LIHC, LUAD, LUSC, OV, STAD) %>% + write_csv(paste("results", "tables", "table_S3.csv", sep = "/")) diff --git a/paper/scripts/r/make_table_S4.R b/workflow/scripts/r/make_table_S4.R similarity index 59% rename from paper/scripts/r/make_table_S4.R rename to workflow/scripts/r/make_table_S4.R index b16d706..f45d6e8 100644 --- a/paper/scripts/r/make_table_S4.R +++ b/workflow/scripts/r/make_table_S4.R @@ -1,40 +1,45 @@ -library(ggplot2) -library(dplyr) -library(vroom) -library(cowplot) -library(ggpubfigs) -library(ggsignif) -library(readr) -library(tidyr) -options(warn = 1) +log <- file(snakemake@log[[1]], open = "wt") +sink(log) +suppressPackageStartupMessages({ + library(ggplot2) + library(dplyr) + library(vroom) + library(cowplot) + library(ggsignif) + library(readr) + library(tidyr) + options(warn = 1) +}) config <- rjson::fromJSON( - file = here::here( - "config.json" - ) + file = snakemake@params[["config_path"]] ) timing <- data.frame( time = c( unlist(as.vector(vroom::vroom( - here::here( - "results", "non_kd", "breslow", "timing.csv" + paste( + "results", "non_kd", "breslow", "timing.csv", + sep = "/" ) ))), unlist(as.vector(vroom::vroom( - here::here( - "results", "non_kd", "breslow", "timing_tuned_l1_ratio.csv" + paste( + "results", "non_kd", "breslow", "timing_tuned_l1_ratio.csv", + sep = "/" ) ))), unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "breslow", "timing_tuned_teacher.csv" + paste( + "results", "kd", "breslow", "timing_tuned_teacher.csv", + sep = "/" ) ))), unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "cox_nnet", "timing.csv" + paste( + "results", "kd", "cox_nnet", "timing.csv", + sep = "/" ) ))) ), @@ -49,4 +54,4 @@ timing %>% group_by(model, cancer) %>% summarise(value = paste0(round(mean(time), 2), " (", round(sd(time), 2), ")")) %>% pivot_wider(names_from = cancer, values_from = value) %>% - write_csv(here::here("tables", "table_S4.csv")) + write_csv(paste("results", "tables", "table_S4.csv", sep = "/")) diff --git a/paper/scripts/r/plot_figure_1.R b/workflow/scripts/r/plot_figure_1.R similarity index 79% rename from paper/scripts/r/plot_figure_1.R rename to workflow/scripts/r/plot_figure_1.R index 272b737..f5bc475 100644
--- a/paper/scripts/r/plot_figure_1.R +++ b/workflow/scripts/r/plot_figure_1.R @@ -1,32 +1,69 @@ -library(ggplot2) -library(dplyr) -library(vroom) -library(cowplot) -library(ggpubfigs) -library(ggsignif) -library(readr) -library(tidyr) -options(warn = 1) - +log <- file(snakemake@log[[1]], open = "wt") +sink(log) + +suppressPackageStartupMessages({ + library(ggplot2) + library(dplyr) + library(vroom) + library(cowplot) + library(ggsignif) + library(readr) + library(tidyr) + options(warn = 1) +}) config <- rjson::fromJSON( - file = here::here( - "config.json" - ) + file = snakemake@params[["config_path"]] ) +theme_big_simple <- function() { + theme_bw(base_size = 16, base_family = "") %+replace% + theme( + plot.background = element_rect(fill = "transparent", colour = NA), + legend.background = element_rect(fill = "transparent", colour = NA), + legend.key = element_rect(fill = "transparent", colour = NA), + legend.title = element_text(size = 24), + legend.text = element_text(size = 20), + axis.line = element_line(color = "black", size = 1, linetype = "solid"), + axis.ticks = element_line(colour = "black", size = 1), + panel.background = element_blank(), + panel.grid.minor = element_blank(), + panel.grid.major = element_blank(), + panel.border = element_blank(), + legend.position = "bottom", + plot.title = element_text(size = 24, hjust = 0.0, vjust = 1.75), + axis.text.x = element_text(color = "black", size = 20, margin = margin(t = 4, r = 0, b = 0, l = 0)), + axis.text.y = element_text(color = "black", size = 20, margin = margin(t = 0, r = 4, b = 0, l = 0)), + axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0), angle = 90, size = 24), + axis.title.x = element_text(margin = margin(t = 10, r = 0, b = 0, l = 0), angle = 0, size = 24), + axis.ticks.length = unit(0.20, "cm"), + strip.background = element_rect(color = "black", size = 1, linetype = "solid"), + strip.text.x = element_text(size = 20, color = "black"), + strip.text.y = element_text(size = 20, color = "black") + ) +} + +friendly_pals <- list( + bright_seven = c("#4477AA", "#228833", "#AA3377", "#BBBBBB", "#66CCEE", "#CCBB44", "#EE6677"), + contrast_three = c("#004488", "#BB5566", "#DDAA33"), + vibrant_seven = c("#0077BB", "#EE7733", "#33BBEE", "#CC3311", "#009988", "#EE3377", "#BBBBBB"), + muted_nine = c("#332288", "#117733", "#CC6677", "#88CCEE", "#999933", "#882255", "#44AA99", "#DDCC77", "#AA4499"), + nickel_five = c("#648FFF", "#FE6100", "#785EF0", "#FFB000", "#DC267F"), + ito_seven = c("#0072B2", "#D55E00", "#009E73", "#CC79A7", "#56B4E9", "#E69F00", "#F0E442"), + ibm_five = c("#648FFF", "#785EF0", "#DC267F", "#FE6100", "#FFB000"), + wong_eight = c("#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#000000"), + tol_eight = c("#332288", "#117733", "#44AA99", "#88CCEE", "#DDCC77", "#CC6677", "#AA4499", "#882255"), + zesty_four = c("#F5793A", "#A95AA1", "#85C0F9", "#0F2080"), + retro_four = c("#601A4A", "#EE442F", "#63ACBE", "#F9F4EC") +) - - -metrics <- vroom::vroom(here::here("results", "metrics", "metrics_overall.csv")) - +metrics <- vroom::vroom(paste("results", "metrics", "metrics_overall.csv", sep = "/")) fig_1_ab <- metrics %>% filter(model %in% c("breslow", "cox_nnet")) %>% filter(lambda %in% c("min", "lambda.min", "pcvl")) %>% filter(metric %in% c("Harrell's C", "Uno's C")) - fig_1_ab$model_type <- ifelse(fig_1_ab$model == "breslow" & fig_1_ab$lambda == "pcvl", "KD Breslow (pcvl)", ifelse(fig_1_ab$model == "breslow" & fig_1_ab$kd, "KD Breslow (min)", @@ -113,7 +150,7 @@ a <- 
fig_1_ab %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.0, 1.05, 1.1, 0.95)), @@ -165,7 +202,7 @@ b <- fig_1_ab %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.1, 1.175, 1.25, 1.025)), @@ -216,7 +253,7 @@ c <- fig_1_cd %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.0, 1.05, 1.1, 0.95)), @@ -236,7 +273,7 @@ c_legend <- fig_1_cd %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.0, 1.05, 1.1, 0.95)), @@ -288,7 +325,7 @@ d <- fig_1_cd %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(0.375, 0.4, 0.425, 0.35)), @@ -301,23 +338,27 @@ d <- fig_1_cd %>% timing <- data.frame( time = c( unlist(as.vector(vroom::vroom( - here::here( - "results", "non_kd", "breslow", "timing.csv" + paste( + "results", "non_kd", "breslow", "timing.csv", + sep = "/" ) ))), unlist(as.vector(vroom::vroom( - here::here( - "results", "non_kd", "breslow", "timing_tuned_l1_ratio.csv" + paste( + "results", "non_kd", "breslow", "timing_tuned_l1_ratio.csv", + sep = "/" ) ))), unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "breslow", "timing_tuned_teacher.csv" + paste( + "results", "kd", "breslow", "timing_tuned_teacher.csv", + sep = "/" ) ))), unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "cox_nnet", "timing.csv" + paste( + "results", "kd", "cox_nnet", "timing.csv", + sep = "/" ) ))) ), @@ -341,8 +382,8 @@ cancer_ordering <- timing %>% f <- ggplot(timing_summarised, aes(x = cancer, group = model)) + geom_line(aes(y = mean, color = model), linewidth = 1) + - scale_color_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 6)]) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 6)]) + + scale_color_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 6)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 6)]) + theme_big_simple() + labs(x = "", y = "Time (s)", fill = "", color = "") + theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1)) @@ -351,44 +392,59 @@ f <- ggplot(timing_summarised, aes(x = cancer, group = model)) + f_legend <- ggplot(timing_summarised, aes(x = cancer, group = model)) + geom_line(aes(y = mean, color = 
model), linewidth = 1) + geom_ribbon(aes(y = mean, ymin = mean - sd, ymax = mean + sd, fill = model), alpha = .1) + - scale_color_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + + scale_color_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + theme_big_simple() + labs(x = "", y = "Time (s)", fill = "", color = "") - sparsity <- data.frame( sparsity = c( c( unlist(as.vector(vroom::vroom( - here::here( - "results", "non_kd", "breslow", "sparsity_vvh_lambda.min.csv" + paste( + "results", "non_kd", "breslow", "sparsity_vvh_lambda.min.csv", + sep = "/" ) )[1:25, ])) ), + sapply(c( + "BLCA", + "BRCA", + "HNSC", + "KIRC", + "LGG", + "LIHC", + "LUAD", + "LUSC", + "OV", + "STAD" + ), function(cancer) { + unname(unlist(vroom::vroom(paste( + "results", "non_kd", "breslow", cancer, "sparsity_tuned_l1_ratio_vvh_lambda.min.csv", + sep = "/" + ), delim = ",")[, 1]))[1:25] + }), unlist(as.vector(vroom::vroom( - here::here( - "results", "non_kd", "breslow", "sparsity_tuned_l1_ratio_vvh_lambda.min.csv" - ) - ))), - unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "breslow", "sparsity_linear_predictor_min.csv" + paste( + "results", "kd", "breslow", "sparsity_linear_predictor_min.csv", + sep = "/" ) )[1:25, ])), unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "breslow", "sparsity_linear_predictor_pcvl.csv" + paste( + "results", "kd", "breslow", "sparsity_linear_predictor_pcvl.csv", + sep = "/" ) ))), unlist(as.vector(vroom::vroom( - here::here( - "results", "kd", "cox_nnet", "sparsity_linear_predictor_min.csv" + paste( + "results", "kd", "cox_nnet", "sparsity_linear_predictor_min.csv", + sep = "/" ) ))) ), - cancer = rep(rep(config$datasets, each = 25), 9), - model = rep(c("glmnet (Breslow)", "glmnet tuned (Breslow)", "KD Breslow (min)", "KD Breslow (pcvl)", "KD Cox-Nnet (min)"), each = 225) + cancer = rep(rep(config$datasets, each = 25), 5), + model = rep(c("glmnet (Breslow)", "glmnet tuned (Breslow)", "KD Breslow (min)", "KD Breslow (pcvl)", "KD Cox-Nnet (min)"), each = 250) ) sparsity$cancer <- factor(sparsity$cancer, levels = sparsity %>% group_by(cancer) %>% summarise(mean = mean(sparsity)) %>% arrange(desc(`mean`)) %>% pull(cancer)) @@ -407,10 +463,10 @@ e <- sparsity %>% ggplot(aes(x = model, y = sparsity, fill = model)) + axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) -metrics <- vroom::vroom(here::here("results", "metrics", "metrics_overall_fixed.csv")) +metrics <- vroom::vroom(paste("results", "metrics", "metrics_overall.csv", sep = "/")) metrics %>% filter(score %in% c("path")) %>% @@ -447,8 +503,8 @@ path_data_summarised$cancer <- factor(path_data_summarised$cancer, levels = canc g <- ggplot(path_data_summarised, aes(x = as.numeric(lambda), group = model_type)) + geom_line(aes(y = mean, color = model_type), linewidth = 1) + geom_ribbon(aes(y = mean, ymin = mean - sd, ymax = mean + sd, fill = model_type), alpha = .1) + - scale_color_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 4)]) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 4)]) + + scale_color_manual(values = friendly_pals$ito_seven[c(1, 4)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 
4)]) + geom_hline(data = teacher_line, aes(yintercept = mean, linetype = "Breslow teacher"), color = "red", lwd = 0.5, linetype = 2, show.legend = FALSE, alpha = 0.75) + facet_wrap(~cancer, scales = "free_y", nrow = 2) + theme_big_simple() + @@ -463,7 +519,7 @@ teacher_legend <- p + geom_hline(aes(lty = "Breslow teacher", yintercept = 20), -metrics <- vroom::vroom(here::here("results", "metrics", "metrics_overall_fixed.csv")) +metrics <- vroom::vroom(paste("results", "metrics", "metrics_overall.csv", sep = "/")) metrics %>% filter(score %in% c("path")) %>% @@ -492,8 +548,8 @@ path_data_summarised <- path_data %>% h <- ggplot(path_data_summarised, aes(x = as.numeric(lambda), group = model_type)) + geom_line(aes(y = mean, color = model_type), linewidth = 1) + geom_ribbon(aes(y = mean, ymin = mean - sd, ymax = mean + sd, fill = model_type), alpha = .1) + - scale_color_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 4)]) + - scale_fill_manual(values = ggpubfigs::friendly_pals$ito_seven[c(1, 4)]) + + scale_color_manual(values = friendly_pals$ito_seven[c(1, 4)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 4)]) + geom_hline(data = teacher_line, aes(yintercept = mean, linetype = "Breslow teacher"), color = "red", lwd = 0.5, linetype = 2, show.legend = FALSE, alpha = 0.75) + facet_wrap(~cancer, scales = "free_y", nrow = 2) + theme_big_simple() + @@ -532,5 +588,5 @@ panels <- plot_grid( label_size = 24 ) -ggsave(here::here("figures", "fig-1_finalized.png"), plot = panels, dpi = 300, height = 20, width = 15, units = "in") -ggsave(here::here("figures", "fig-1_finalized.svg"), plot = panels, dpi = 300, height = 20, width = 15, units = "in") +ggsave(paste("results", "figures", "fig-1_finalized.png", sep = "/"), plot = panels, dpi = 300, height = 20, width = 15, units = "in") +ggsave(paste("results", "figures", "fig-1_finalized.svg", sep = "/"), plot = panels, dpi = 300, height = 20, width = 15, units = "in") diff --git a/paper/scripts/r/plot_figure_S1.R b/workflow/scripts/r/plot_figure_S1.R similarity index 87% rename from paper/scripts/r/plot_figure_S1.R rename to workflow/scripts/r/plot_figure_S1.R index 031e16e..ddd9eb8 100644 --- a/paper/scripts/r/plot_figure_S1.R +++ b/workflow/scripts/r/plot_figure_S1.R @@ -1,31 +1,70 @@ -library(ggplot2) -library(dplyr) -library(vroom) -library(cowplot) -library(ggpubfigs) -library(ggsignif) -options(warn = 1) - +log <- file(snakemake@log[[1]], open = "wt") +sink(log) + +suppressPackageStartupMessages({ + library(ggplot2) + library(dplyr) + library(vroom) + library(cowplot) + library(ggsignif) + options(warn = 1) +}) config <- rjson::fromJSON( - file = here::here( - "config.json" - ) + file = snakemake@params[["config_path"]] ) +theme_big_simple <- function() { + theme_bw(base_size = 16, base_family = "") %+replace% + theme( + plot.background = element_rect(fill = "transparent", colour = NA), + legend.background = element_rect(fill = "transparent", colour = NA), + legend.key = element_rect(fill = "transparent", colour = NA), + legend.title = element_text(size = 24), + legend.text = element_text(size = 20), + axis.line = element_line(color = "black", size = 1, linetype = "solid"), + axis.ticks = element_line(colour = "black", size = 1), + panel.background = element_blank(), + panel.grid.minor = element_blank(), + panel.grid.major = element_blank(), + panel.border = element_blank(), + legend.position = "bottom", + plot.title = element_text(size = 24, hjust = 0.0, vjust = 1.75), + axis.text.x = element_text(color = "black", size = 20, margin 
= margin(t = 4, r = 0, b = 0, l = 0)), + axis.text.y = element_text(color = "black", size = 20, margin = margin(t = 0, r = 4, b = 0, l = 0)), + axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0), angle = 90, size = 24), + axis.title.x = element_text(margin = margin(t = 10, r = 0, b = 0, l = 0), angle = 0, size = 24), + axis.ticks.length = unit(0.20, "cm"), + strip.background = element_rect(color = "black", size = 1, linetype = "solid"), + strip.text.x = element_text(size = 20, color = "black"), + strip.text.y = element_text(size = 20, color = "black") + ) +} + +friendly_pals <- list( + bright_seven = c("#4477AA", "#228833", "#AA3377", "#BBBBBB", "#66CCEE", "#CCBB44", "#EE6677"), + contrast_three = c("#004488", "#BB5566", "#DDAA33"), + vibrant_seven = c("#0077BB", "#EE7733", "#33BBEE", "#CC3311", "#009988", "#EE3377", "#BBBBBB"), + muted_nine = c("#332288", "#117733", "#CC6677", "#88CCEE", "#999933", "#882255", "#44AA99", "#DDCC77", "#AA4499"), + nickel_five = c("#648FFF", "#FE6100", "#785EF0", "#FFB000", "#DC267F"), + ito_seven = c("#0072B2", "#D55E00", "#009E73", "#CC79A7", "#56B4E9", "#E69F00", "#F0E442"), + ibm_five = c("#648FFF", "#785EF0", "#DC267F", "#FE6100", "#FFB000"), + wong_eight = c("#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#000000"), + tol_eight = c("#332288", "#117733", "#44AA99", "#88CCEE", "#DDCC77", "#CC6677", "#AA4499", "#882255"), + zesty_four = c("#F5793A", "#A95AA1", "#85C0F9", "#0F2080"), + retro_four = c("#601A4A", "#EE442F", "#63ACBE", "#F9F4EC") +) -metrics <- vroom::vroom(here::here("results", "metrics", "metrics_overall.csv")) %>% filter(score != "teacher") -metrics_teacher <- vroom::vroom(here::here("results", "metrics", "metrics_overall_teachers.csv")) +metrics <- vroom::vroom(paste("results", "metrics", "metrics_overall.csv", sep = "/")) %>% filter(score != "teacher") +metrics_teacher <- vroom::vroom(paste("results", "metrics", "metrics_overall_teachers.csv", sep = "/")) metrics <- rbind(metrics, metrics_teacher) - fig_1_ab <- metrics %>% filter(model %in% c("breslow", "cox_nnet")) %>% filter(lambda %in% c("min", "lambda.min", "pcvl", "teacher")) %>% filter(metric %in% c("Harrell's C", "Uno's C")) %>% filter(!(lambda == "teacher" & model == "cox_nnet")) - fig_1_ab$model_type <- ifelse(fig_1_ab$model == "breslow" & fig_1_ab$lambda == "pcvl", "KD Breslow (pcvl)", ifelse(fig_1_ab$model == "breslow" & fig_1_ab$kd & !fig_1_ab$lambda == "teacher", "KD Breslow (min)", @@ -133,7 +172,7 @@ a <- fig_1_ab %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)])) + + scale_fill_manual(values = c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.16, 1.11, 1.06, 1.01, 0.95)), @@ -191,7 +230,7 @@ b <- fig_1_ab %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)])) + + scale_fill_manual(values = c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.25, 1.175, 1.1, 1.025, 0.95)), @@ -248,7 +287,7 @@ c <- fig_1_cd %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7, 3)])) + + scale_fill_manual(values = 
c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7, 3)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.25, 1.175, 1.1, 1.025, 0.95)), @@ -267,7 +306,7 @@ c_legend <- fig_1_cd %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)])) + + scale_fill_manual(values = c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.15, 1.125, 1.075, 1.025, 0.95)), @@ -324,7 +363,7 @@ d <- fig_1_cd %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)])) + + scale_fill_manual(values = c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = (rev(c(0.35, 0.385, 0.41, 0.435, 0.46))) + 0.025), @@ -337,8 +376,8 @@ d <- fig_1_cd %>% -metrics <- vroom::vroom(here::here("results", "metrics", "metrics_overall.csv")) %>% filter(score != "teacher") -metrics_teacher <- vroom::vroom(here::here("results", "metrics", "metrics_overall_teachers.csv")) +metrics <- vroom::vroom(paste("results", "metrics", "metrics_overall.csv", sep = "/")) %>% filter(score != "teacher") +metrics_teacher <- vroom::vroom(paste("results", "metrics", "metrics_overall_teachers.csv", sep = "/")) metrics <- rbind(metrics, metrics_teacher) @@ -456,7 +495,7 @@ a_bottom <- fig_1_ab %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + + scale_fill_manual(values = c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.16, 1.11, 1.06, 1.01, 0.95)), @@ -514,7 +553,7 @@ b_bottom <- fig_1_ab %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + + scale_fill_manual(values = c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = (c(1.275, 1.2, 1.125, 1.05, 0.95)) + 0.1), @@ -571,7 +610,7 @@ c_bottom <- fig_1_cd %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + + scale_fill_manual(values = c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.16, 1.11, 1.06, 1.01, 0.95)), @@ -590,7 +629,7 @@ c_bottom_legend <- fig_1_cd %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + + scale_fill_manual(values = c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = end, annotations = pval, y_position = c(1.16, 1.11, 1.06, 1.01, 0.95)), @@ -647,7 +686,7 @@ d_bottom <- fig_1_cd %>% axis.text.x = element_blank(), axis.ticks.x = element_blank() ) + - scale_fill_manual(values = c(ggpubfigs::friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + + scale_fill_manual(values = c(friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 3)])) + geom_signif( data = signif_frame, aes(xmin = start, xmax = 
end, annotations = pval, y_position = (rev(c(0.35, 0.385, 0.41, 0.435, 0.46))) + 0.025), @@ -670,5 +709,5 @@ second_row_with_legend <- cowplot::plot_grid(second_row, boxplot_legend, rel_hei s1 <- cowplot::plot_grid(first_row_with_legend, second_row_with_legend, nrow = 2) -ggsave(here::here("figures", "fig-S1_finalized.pdf"), plot = s1, dpi = 300, height = 20 / 1.5, width = 15, units = "in") -ggsave(here::here("figures", "fig-S1_finalized.svg"), plot = s1, dpi = 300, height = 20 / 1.5, width = 15, units = "in") +ggsave(paste("results", "figures", "fig-S1_finalized.pdf", sep = "/"), plot = s1, dpi = 300, height = 20 / 1.5, width = 15, units = "in") +ggsave(paste("results", "figures", "fig-S1_finalized.svg", sep = "/"), plot = s1, dpi = 300, height = 20 / 1.5, width = 15, units = "in") diff --git a/workflow/scripts/r/plot_figure_S2.R b/workflow/scripts/r/plot_figure_S2.R new file mode 100644 index 0000000..d6b08d0 --- /dev/null +++ b/workflow/scripts/r/plot_figure_S2.R @@ -0,0 +1,203 @@ +log <- file(snakemake@log[[1]], open = "wt") +sink(log) + +suppressPackageStartupMessages({ + library(ggplot2) + library(dplyr) + library(vroom) + library(cowplot) + library(ggsignif) + options(warn = 1) +}) + +config <- rjson::fromJSON( + file = snakemake@params[["config_path"]] +) + +theme_big_simple <- function() { + theme_bw(base_size = 16, base_family = "") %+replace% + theme( + plot.background = element_rect(fill = "transparent", colour = NA), + legend.background = element_rect(fill = "transparent", colour = NA), + legend.key = element_rect(fill = "transparent", colour = NA), + legend.title = element_text(size = 24), + legend.text = element_text(size = 20), + axis.line = element_line(color = "black", size = 1, linetype = "solid"), + axis.ticks = element_line(colour = "black", size = 1), + panel.background = element_blank(), + panel.grid.minor = element_blank(), + panel.grid.major = element_blank(), + panel.border = element_blank(), + legend.position = "bottom", + plot.title = element_text(size = 24, hjust = 0.0, vjust = 1.75), + axis.text.x = element_text(color = "black", size = 20, margin = margin(t = 4, r = 0, b = 0, l = 0)), + axis.text.y = element_text(color = "black", size = 20, margin = margin(t = 0, r = 4, b = 0, l = 0)), + axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0), angle = 90, size = 24), + axis.title.x = element_text(margin = margin(t = 10, r = 0, b = 0, l = 0), angle = 0, size = 24), + axis.ticks.length = unit(0.20, "cm"), + strip.background = element_rect(color = "black", size = 1, linetype = "solid"), + strip.text.x = element_text(size = 20, color = "black"), + strip.text.y = element_text(size = 20, color = "black") + ) +} + +friendly_pals <- list( + bright_seven = c("#4477AA", "#228833", "#AA3377", "#BBBBBB", "#66CCEE", "#CCBB44", "#EE6677"), + contrast_three = c("#004488", "#BB5566", "#DDAA33"), + vibrant_seven = c("#0077BB", "#EE7733", "#33BBEE", "#CC3311", "#009988", "#EE3377", "#BBBBBB"), + muted_nine = c("#332288", "#117733", "#CC6677", "#88CCEE", "#999933", "#882255", "#44AA99", "#DDCC77", "#AA4499"), + nickel_five = c("#648FFF", "#FE6100", "#785EF0", "#FFB000", "#DC267F"), + ito_seven = c("#0072B2", "#D55E00", "#009E73", "#CC79A7", "#56B4E9", "#E69F00", "#F0E442"), + ibm_five = c("#648FFF", "#785EF0", "#DC267F", "#FE6100", "#FFB000"), + wong_eight = c("#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#000000"), + tol_eight = c("#332288", "#117733", "#44AA99", "#88CCEE", "#DDCC77", "#CC6677", "#AA4499", "#882255"), + 
zesty_four = c("#F5793A", "#A95AA1", "#85C0F9", "#0F2080"), + retro_four = c("#601A4A", "#EE442F", "#63ACBE", "#F9F4EC") +) + +timing <- data.frame( + time = c( + unlist(as.vector(vroom::vroom( + paste( + "results", "non_kd", "breslow", "timing.csv", + sep = "/" + ) + ))), + unlist(as.vector(vroom::vroom( + paste( + "results", "non_kd", "breslow", "timing_tuned_l1_ratio.csv", + sep = "/" + ) + ))), + unlist(as.vector(vroom::vroom( + paste( + "results", "kd", "breslow", "timing_tuned_teacher.csv", + sep = "/" + ) + ))), + unlist(as.vector(vroom::vroom( + paste( + "results", "kd", "cox_nnet", "timing.csv", + sep = "/" + ) + ))) + ), + cancer = rep(rep(config$datasets, each = 5), 4), + model = rep(c("glmnet (Breslow)", "glmnet tuned (Breslow)", "KD Breslow", "KD Cox-Nnet"), each = 50) +) + +timing$cancer <- factor(timing$cancer, levels = timing %>% group_by(cancer) %>% summarise(mean = mean(time)) %>% arrange(desc(`mean`)) %>% pull(cancer)) +timing$model <- factor(timing$model, levels = c("glmnet (Breslow)", "glmnet tuned (Breslow)", "KD Breslow", "KD Cox-Nnet")) +timing_summarised <- timing %>% + group_by(cancer, model) %>% + summarise(mean = mean(time), sd = sd(time) / sqrt(n())) + +cancer_ordering <- timing %>% + group_by(cancer) %>% + summarise(mean = mean(time)) %>% + arrange(desc(`mean`)) %>% + pull(cancer) + +metrics <- vroom::vroom(paste("results", "metrics", "metrics_overall.csv", sep = "/")) + +metrics %>% + filter(score %in% c("path")) %>% + filter((metric == "Antolini's C" & model == "breslow" & !kd) | (kd & model == "cox_nnet" & metric == "Antolini's C")) -> path_data +teacher_line <- metrics %>% + filter(score %in% c("teacher")) %>% + filter(metric == "Antolini's C") %>% + filter(model == "cox_nnet") %>% + group_by(cancer) %>% + summarise(mean = mean(value)) +path_data$cancer <- factor(path_data$cancer, levels = cancer_ordering) + + +path_data$model_type <- ifelse(path_data$kd, "KD Cox-Nnet", + "glmnet (Breslow)" +) + + +path_data$model_type <- factor(path_data$model_type, levels = c("glmnet (Breslow)", "KD Cox-Nnet")) +path_data$cancer <- factor(path_data$cancer, as.character(cancer_ordering)) +teacher_line$cancer <- factor(teacher_line$cancer, as.character(cancer_ordering)) +path_data_summarised <- path_data %>% + group_by(cancer, model_type, lambda) %>% + summarise(mean = mean(value), sd = sd(value) / sqrt(n())) + +path_data_summarised$cancer <- factor(path_data_summarised$cancer, levels = cancer_ordering) +g <- ggplot(path_data_summarised, aes(x = as.numeric(lambda), group = model_type)) + + geom_line(aes(y = mean, color = model_type), linewidth = 1) + + geom_ribbon(aes(y = mean, ymin = mean - sd, ymax = mean + sd, fill = model_type), alpha = .1) + + scale_color_manual(values = friendly_pals$ito_seven[c(1, 6)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 6)]) + + geom_hline(data = teacher_line, aes(yintercept = mean, linetype = "Cox-Nnet teacher"), color = "red", lwd = 0.5, linetype = 2, show.legend = FALSE, alpha = 0.75) + + facet_wrap(~cancer, scales = "free_y", nrow = 2) + + theme_big_simple() + + labs(x = "Regularization index (from sparse to dense)", y = "Antolini's C", fill = "", color = "") + + + +# Dummy mtcars plot, used only so a standalone legend for the dashed teacher reference line can be extracted below. +p <- ggplot(mtcars, aes(x = wt, y = mpg)) + + geom_point() +teacher_legend <- p + geom_hline(aes(lty = "Cox-Nnet teacher", yintercept = 20), linewidth = 1, color = "red", show.legend = TRUE) + scale_linetype_manual(name = "", values = 2) + theme_big_simple() + guides(color = guide_legend(override.aes = list(linetype = c("dashed")))) +
theme(legend.key.width = unit(2, "cm")) + + +metrics <- vroom::vroom(paste("results", "metrics", "metrics_overall.csv", sep = "/")) + +metrics %>% + filter(score %in% c("path")) %>% + filter((metric == "IBS" & model == "breslow" & !kd) | (kd & model == "cox_nnet" & metric == "IBS")) -> path_data +teacher_line <- metrics %>% + filter(score %in% c("teacher")) %>% + filter(metric == "IBS") %>% + filter(model == "cox_nnet") %>% + group_by(cancer) %>% + summarise(mean = mean(value)) + + +path_data$model_type <- ifelse(path_data$kd, "KD Cox-Nnet", + "glmnet (Breslow)" +) + + +path_data$model_type <- factor(path_data$model_type, levels = c("glmnet (Breslow)", "KD Cox-Nnet")) +path_data$cancer <- factor(path_data$cancer, levels = as.character(cancer_ordering)) +teacher_line$cancer <- factor(teacher_line$cancer, as.character(cancer_ordering)) +path_data_summarised <- path_data %>% + group_by(cancer, model_type, lambda) %>% + summarise(mean = mean(value), sd = sd(value) / sqrt(n())) + + +h <- ggplot(path_data_summarised, aes(x = as.numeric(lambda), group = model_type)) + + geom_line(aes(y = mean, color = model_type), linewidth = 1) + + geom_ribbon(aes(y = mean, ymin = mean - sd, ymax = mean + sd, fill = model_type), alpha = .1) + + scale_color_manual(values = friendly_pals$ito_seven[c(1, 6)]) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 6)]) + + geom_hline(data = teacher_line, aes(yintercept = mean, linetype = "Cox-Nnet teacher"), color = "red", lwd = 0.5, linetype = 2, show.legend = FALSE, alpha = 0.75) + + facet_wrap(~cancer, scales = "free_y", nrow = 2) + + theme_big_simple() + + labs(x = "Regularization index (from sparse to dense)", y = "Integrated Brier Score", fill = "", color = "") + + +line_legend <- get_legend( + g + theme(legend.box.margin = margin(0, 0, 0, 0)) +) + +teacher_legend <- get_legend( + teacher_legend + theme(legend.box.margin = margin(0, 0, 0, 0)) +) + +both_legends <- plot_grid( + line_legend, teacher_legend +) + +reg_path <- plot_grid( + cowplot::plot_grid(g + theme(legend.position = "none"), both_legends, rel_heights = c(0.95, 0.1), nrow = 2, ncol = 1), + cowplot::plot_grid(h + theme(legend.position = "none"), both_legends, rel_heights = c(0.95, 0.1), nrow = 2, ncol = 1), + labels = c("A", "B"), + nrow = 2, + label_size = 24 +) + +ggsave(paste("results", "figures", "fig-S2_finalized.pdf", sep = "/"), plot = reg_path, dpi = 300, height = 20 / 1.75, width = 15, units = "in") +ggsave(paste("results", "figures", "fig-S2_finalized.svg", sep = "/"), plot = reg_path, dpi = 300, height = 20 / 1.75, width = 15, units = "in") diff --git a/workflow/scripts/r/plot_figure_S3.R b/workflow/scripts/r/plot_figure_S3.R new file mode 100644 index 0000000..4201db0 --- /dev/null +++ b/workflow/scripts/r/plot_figure_S3.R @@ -0,0 +1,122 @@ +log <- file(snakemake@log[[1]], open = "wt") +sink(log) + +suppressPackageStartupMessages({ + library(ggplot2) + library(dplyr) + library(vroom) + library(cowplot) + library(ggsignif) + options(warn = 1) +}) + +config <- rjson::fromJSON( + file = snakemake@params[["config_path"]] +) + +theme_big_simple <- function() { + theme_bw(base_size = 16, base_family = "") %+replace% + theme( + plot.background = element_rect(fill = "transparent", colour = NA), + legend.background = element_rect(fill = "transparent", colour = NA), + legend.key = element_rect(fill = "transparent", colour = NA), + legend.title = element_text(size = 24), + legend.text = element_text(size = 20), + axis.line = element_line(color = "black", size = 1, linetype = "solid"), + 
axis.ticks = element_line(colour = "black", size = 1), + panel.background = element_blank(), + panel.grid.minor = element_blank(), + panel.grid.major = element_blank(), + panel.border = element_blank(), + legend.position = "bottom", + plot.title = element_text(size = 24, hjust = 0.0, vjust = 1.75), + axis.text.x = element_text(color = "black", size = 20, margin = margin(t = 4, r = 0, b = 0, l = 0)), + axis.text.y = element_text(color = "black", size = 20, margin = margin(t = 0, r = 4, b = 0, l = 0)), + axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0), angle = 90, size = 24), + axis.title.x = element_text(margin = margin(t = 10, r = 0, b = 0, l = 0), angle = 0, size = 24), + axis.ticks.length = unit(0.20, "cm"), + strip.background = element_rect(color = "black", size = 1, linetype = "solid"), + strip.text.x = element_text(size = 20, color = "black"), + strip.text.y = element_text(size = 20, color = "black") + ) +} + +friendly_pals <- list( + bright_seven = c("#4477AA", "#228833", "#AA3377", "#BBBBBB", "#66CCEE", "#CCBB44", "#EE6677"), + contrast_three = c("#004488", "#BB5566", "#DDAA33"), + vibrant_seven = c("#0077BB", "#EE7733", "#33BBEE", "#CC3311", "#009988", "#EE3377", "#BBBBBB"), + muted_nine = c("#332288", "#117733", "#CC6677", "#88CCEE", "#999933", "#882255", "#44AA99", "#DDCC77", "#AA4499"), + nickel_five = c("#648FFF", "#FE6100", "#785EF0", "#FFB000", "#DC267F"), + ito_seven = c("#0072B2", "#D55E00", "#009E73", "#CC79A7", "#56B4E9", "#E69F00", "#F0E442"), + ibm_five = c("#648FFF", "#785EF0", "#DC267F", "#FE6100", "#FFB000"), + wong_eight = c("#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#000000"), + tol_eight = c("#332288", "#117733", "#44AA99", "#88CCEE", "#DDCC77", "#CC6677", "#AA4499", "#882255"), + zesty_four = c("#F5793A", "#A95AA1", "#85C0F9", "#0F2080"), + retro_four = c("#601A4A", "#EE442F", "#63ACBE", "#F9F4EC") +) + +metrics <- vroom::vroom(paste("results", "metrics", "metrics_overall.csv", sep = "/")) +metrics_125 <- vroom::vroom(paste("results", "metrics", "metrics_overall_125_full.csv", sep = "/")) +metrics_stratified <- vroom::vroom(paste("results", "metrics", "metrics_overall_cved.csv", sep = "/")) + +metrics <- rbind( + cbind(metrics %>% filter(model == "breslow" & lambda %in% c("lambda.min", "min")), + calc_type = "5-fold CV 5 reps (per split)" + ), + cbind(metrics_125, calc_type = "5-fold CV 25 reps (per split)"), + cbind(metrics_stratified, calc_type = "5-fold CV 25 reps (per CV)") +) + +metrics$model_type <- ifelse( + metrics$kd, "KD Breslow (min)", + ifelse( + metrics$tuned, "glmnet tuned (Breslow)", + "glmnet (Breslow)" + ) +) + +a <- metrics %>% + filter(metric == "Harrell's C") %>% + ggplot(aes(x = model_type, y = value, fill = model_type)) + + geom_boxplot() + + theme_big_simple() + + labs(x = "", y = "Harrell's C", fill = "") + + facet_wrap(~ interaction(calc_type)) + + theme( + axis.title.x = element_blank(), + axis.text.x = element_blank(), + axis.ticks.x = element_blank() + ) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + +b <- metrics %>% + filter(metric == "Uno's C") %>% + ggplot(aes(x = model_type, y = value, fill = model_type)) + + geom_boxplot() + + theme_big_simple() + + labs(x = "", y = "Uno's C", fill = "") + + facet_wrap(~ interaction(calc_type)) + + theme( + axis.title.x = element_blank(), + axis.text.x = 
element_blank(), + axis.ticks.x = element_blank() + ) + + scale_fill_manual(values = friendly_pals$ito_seven[c(1, 2, 4, 5, 6, 7)]) + + +cv_fig <- cowplot::plot_grid(a, b, + labels = c("A", "B"), + nrow = 2, + label_size = 24 +) + +ggsave(paste("results", "figures", "fig-S3_finalized.pdf", sep = "/"), plot = cv_fig, dpi = 300, height = 20 / 1.5, width = 15, units = "in") +ggsave(paste("results", "figures", "fig-S3_finalized.svg", sep = "/"), plot = cv_fig, dpi = 300, height = 20 / 1.5, width = 15, units = "in") diff --git a/paper/scripts/r/run_preprocessing.R b/workflow/scripts/r/preprocess_data.R similarity index 60% rename from paper/scripts/r/run_preprocessing.R rename to workflow/scripts/r/preprocess_data.R index 9964c47..6e6e9d3 100644 --- a/paper/scripts/r/run_preprocessing.R +++ b/workflow/scripts/r/preprocess_data.R @@ -1,21 +1,3 @@ -suppressPackageStartupMessages({ - library(dplyr) - library(readr) - library(rjson) - library(fastDummies) - library(tidyr) - library(janitor) - library(forcats) - library(stringr) - library(rjson) - library(tibble) - library(vroom) - library(here) - library(readr) - library(vroom) - library(readxl) -}) - #' Chooses the proper sample in case donors have multiple primary samples. #' We first chose the "lower" vial (i.e., we pick vial A over vial B, since #' the lower vials tend to be much more common). Afterward, if there are still @@ -164,94 +146,24 @@ prepare_clinical_data <- function(clinical_raw, clinical_ext_raw, cancer) { return(clinical) } -#' Performs complete preprocessing of TCGA-PANCANATLAS CNV data. -#' -#' @param cnv data.frame. data.frame containing CNV data to be preprocessed. -#' @returns data.frame. Preprocessed data.frame. -prepare_cnv <- function(cnv) { - rownames(cnv) <- cnv[, 1] - cnv <- cnv[, 2:ncol(cnv)] - cnv <- preprocess(cnv, log = FALSE) - return(cnv) -} - -#' Performs complete preprocessing of TCGA-PANCANATLAS DNA methylation data. -#' -#' @param meth data.frame. data.frame containing DNA methylation data to be preprocessed. -#' @returns data.frame. Preprocessed data.frame. -prepare_meth_pancan <- function(meth) { - rownames(meth) <- meth[, 1] - meth <- meth[, 2:ncol(meth)] - meth <- preprocess(meth, log = FALSE) - return(meth) -} -#' Performs complete preprocessing of TCGA-PANCANATLAS mutation data. -#' -#' @param mut data.frame. data.frame containing mutation data to be preprocessed. -#' @returns data.frame. Preprocessed data.frame. -prepare_mutation <- function(mut) { - mut <- preprocess(mut, log = FALSE) - mut -} -#' Performs complete preprocessing of TCGA-PANCANATLAS protein expression data. -#' -#' @param rppa data.frame. data.frame containing protein expression data to be preprocessed. -#' @returns data.frame. Preprocessed data.frame. -prepare_rppa_pancan <- function(rppa) { - rownames(rppa) <- rppa[, 1] - rppa <- t(rppa[, 2:ncol(rppa)]) - rppa <- preprocess(rppa, log = FALSE) - rppa -} - -#' Performs complete preprocessing of TCGA-PANCANATLAS miRNA data. -#' -#' @param rppa data.frame. data.frame containing miRNA data to be preprocessed. -#' @returns data.frame. Preprocessed data.frame. -prepare_mirna_pancan <- function(mirna) { - rownames(mirna) <- mirna[, 1] - mirna <- mirna[, 2:ncol(mirna)] - mrina <- preprocess(mirna, log = TRUE) -} - -#' Helper function to perform complete preprocessing for TCGA datasets. Writes -#' datasets directly to disk, separated by complete and missing modality samples. -#' -#' @param cancer character. Cancer dataset to be prepared. -#' @param include_rppa logical. 
Whether the dataset should contain RPPA data. -#' @param include_mirna logical. Whether the dataset should contain miRNA data. -#' @param include_mutation logical. Whether the dataset should contain mutation data. -#' @param include_methylation logical. Whether the dataset should contain DNA methylation data. -#' @param include_gex logical. Whether the dataset should contain mRNA data. -#' @param include_cnv logical. Whether the dataset should contain CNV data. -#' @returns NULL. -prepare_new_cancer_dataset <- function(cancer, - tcga_cdr_master, - tcga_w_followup_master, - gex_master) { - config <- rjson::fromJSON( - file = here::here("config.json") - ) - # Preprocess modalities one after the other taking input parameters into account. - # NOTE: We separate complete and missing data modalities by collecting the barcodes - # per each modality before we append the missing modality samples. - # NB: Appending the missing modality samples is necessary such that we still - # have the same features/modalities for all samples in the missing - # modality samples (even if some modalities are completely absent for some samples). +preprocess_data <- function(cancer, + tcga_cdr_master, + tcga_w_followup_master, + gex_master, + output_path) { clinical <- prepare_clinical_data(tcga_cdr_master, tcga_w_followup_master, cancer = cancer) sample_barcodes <- list(clinical$patient_id) - patients <- unname(unlist(sapply(clinical$patient_id, function(x) grep(x, colnames(gex_master))))) - gex_filtered <- gex_master[, c(1, patients)] - gex <- prepare_gene_expression_pancan(gex_filtered) - sample_barcodes <- append(sample_barcodes, list(colnames(gex))) + patients <- unname(unlist(sapply(clinical$patient_id, function(x) grep(x, colnames(gex_master))))) + gex_filtered <- gex_master[, c(1, patients)] + gex <- prepare_gene_expression_pancan(gex_filtered) + sample_barcodes <- append(sample_barcodes, list(colnames(gex))) - # Get set of common (complete) samples and write it to disk. common_samples <- Reduce(intersect, sample_barcodes) data <- clinical %>% filter(patient_id %in% common_samples) %>% - arrange(desc(patient_id)) #%>% + arrange(desc(patient_id)) data <- data %>% cbind( data.frame(t(gex), check.names = FALSE) %>% @@ -261,54 +173,58 @@ prepare_new_cancer_dataset <- function(cancer, dplyr::select(-rowname) %>% rename_with(function(x) paste0("gex_", x)) ) - + print(paste0("Writing: ", cancer)) data %>% - # Rename to `OS_days` for consistency with other projects/datasets. rename(OS_days = OS.time) %>% write_csv( - here::here("data", "processed", "TCGA", paste0(cancer, "_data_preprocessed.csv")) + output_path ) - return(NULL) } -#' Helper function to rerun our complete preprocessing in R. -#' -#' @returns NULL. All preprocessed datasets are directly written to disk. -rerun_preprocessing_R <- function() { - config <- rjson::fromJSON( - file = here::here("config.json") - ) +rerun_preprocessing_R <- function(gex_path, cdr_path, followup_path, cancer, output_path) { # Increase VROOM connection size for larger PANCAN files. Sys.setenv("VROOM_CONNECTION_SIZE" = 131072 * 8) - # Read in all PANCAN files. 
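+  # Read the three PANCAN master inputs (gene expression, TCGA-CDR clinical, and follow-up) once per rule invocation; their paths are supplied via snakemake@input rather than hard-coded here::here() locations.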
gex_master <- vroom( - here::here( - "data", "raw", "EBPlusPlusAdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.tsv" - ) + gex_path ) %>% data.frame(check.names = FALSE) tcga_cdr <- readxl::read_xlsx( - here::here("data", "raw", "TCGA-CDR-SupplementalTableS1.xlsx"), + cdr_path, guess_max = 2500, range = cell_cols("B:AH") ) tcga_w_followup <- read_tsv( - here::here( - "data", "raw", "clinical_PANCAN_patient_with_followup.tsv" - ), + followup_path, guess_max = 1e5 ) - for (cancer in config$datasets) { - prepare_new_cancer_dataset( - cancer = cancer, - tcga_cdr_master = tcga_cdr, - tcga_w_followup_master = tcga_w_followup, - gex_master = gex_master - ) - } - return(NULL) + preprocess_data( + cancer = cancer, + tcga_cdr_master = tcga_cdr, + tcga_w_followup_master = tcga_w_followup, + gex_master = gex_master, + output_path = output_path + ) + return(0) } -rerun_preprocessing_R() +log <- file(snakemake@log[[1]], open = "wt") +sink(log) +suppressPackageStartupMessages({ + library(dplyr) + library(readr) + library(tidyr) + library(stringr) + library(vroom) + library(readxl) + library(tibble) +}) + +rerun_preprocessing_R( + gex_path = snakemake@input[["gex"]], + cdr_path = snakemake@input[["cdr"]], + followup_path = snakemake@input[["followup"]], + cancer = snakemake@params[["cancer"]], + output_path = snakemake@output[["output_path"]] +) diff --git a/paper/scripts/r/run_glmnet.R b/workflow/scripts/r/run_glmnet.R similarity index 71% rename from paper/scripts/r/run_glmnet.R rename to workflow/scripts/r/run_glmnet.R index 4426beb..a6b9f90 100644 --- a/paper/scripts/r/run_glmnet.R +++ b/workflow/scripts/r/run_glmnet.R @@ -1,31 +1,33 @@ -library(rjson) -library(glmnet) -library(survival) -library(coefplot) -library(pec) -library(readr) -library(vroom) -library(dplyr) -library(splitTools) -library(glmnetUtils) - +log <- file(snakemake@log[[1]], open = "wt") +sink(log) + +suppressPackageStartupMessages({ + library(glmnet) + library(survival) + library(coefplot) + library(pec) + library(readr) + library(vroom) + library(dplyr) + library(splitTools) + library(glmnetUtils) +}) glmnet::glmnet.control( - fdev = 0, - devmax = 1.0 + fdev = 0, + devmax = 1.0 ) - config <- rjson::fromJSON( - file = here::here( - "config.json" - ) + file = snakemake@params[["config_path"]] ) # Get alpha. 
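+# For cva.glmnet fits, get_alpha() below selects the elastic-net mixing parameter whose inner-CV model attains the smallest mean cross-validated error (min(mod$cvm)) across the alpha grid.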
get_alpha <- function(fit) { alpha <- fit$alpha - error <- sapply(fit$modlist, function(mod) {min(mod$cvm)}) + error <- sapply(fit$modlist, function(mod) { + min(mod$cvm) + }) alpha[which.min(error)] } @@ -53,26 +55,29 @@ for (tune_l1_ratio in c(FALSE)) { result_sparsity <- c() lp_df <- list() data <- data.frame(vroom::vroom( - here::here( - "data", "processed", "TCGA", - paste0(cancer, "_data_preprocessed.csv") + paste( + "results", "preprocess_data", + paste0(cancer, ".csv"), + sep = "/" ) )[, -1], check.names = FALSE) train_splits <- data.frame(vroom::vroom( - here::here( - "data", "splits", "TCGA", - paste0(cancer, "_train_splits.csv") + paste( + "results", "make_splits", + paste0(cancer, "_train_splits.csv"), + sep = "/" ) ), check.names = FALSE) test_splits <- data.frame(vroom::vroom( - here::here( - "data", "splits", "TCGA", - paste0(cancer, "_test_splits.csv") + paste( + "results", "make_splits", + paste0(cancer, "_test_splits.csv"), + sep = "/" ) ), check.names = FALSE) - for (split in 1:(config$n_outer_splits * config$n_outer_repetitions)) { + for (split in 1:125) { train_ix <- as.numeric(unname(train_splits[split, ])) train_ix <- train_ix[!is.na(train_ix)] + 1 @@ -99,21 +104,20 @@ for (tune_l1_ratio in c(FALSE)) { { if (tune_l1_ratio) { fit <- cva.glmnet( - x = as.matrix(X_train), - y = y_train, - family = "cox", - alpha = config$l1_ratio_tuned, - lambda.min.ratio = config$eps, - standardize = TRUE, - nlambda = config$n_alphas, - nfolds = config$n_inner_cv, - grouped = score == "vvh", - foldid = fold_ids - ) - + x = as.matrix(X_train), + y = y_train, + family = "cox", + alpha = config$l1_ratio_tuned, + lambda.min.ratio = config$eps, + standardize = TRUE, + nlambda = config$n_alphas, + nfolds = config$n_inner_cv, + grouped = score == "vvh", + foldid = fold_ids + ) + fit <- fit$modlist[[which((fit$alpha == get_alpha(fit)))]] - } - else { + } else { fit <- cv.glmnet( x = as.matrix(X_train), y = y_train, @@ -130,8 +134,7 @@ for (tune_l1_ratio in c(FALSE)) { if (cv_score %in% c("lambda.min", "lambda.1se")) { n_sparsity <- nrow(extract.coef(fit, cv_score)) - print(n_sparsity) - if (n_sparsity == 0) { + if (n_sparsity == 0) { stop() } linear_predictor <- as.vector(predict(fit, as.matrix(X_test), s = fit$lambda.min)) @@ -168,7 +171,6 @@ for (tune_l1_ratio in c(FALSE)) { list(sparsity = n_sparsity, linear_predictor = linear_predictor, surv = surv, failures = 0) }, error = function(cond) { - print(cond) times <- sort(unique(y_test[, 1])) km <- exp(-survfit(y_test ~ 1)$cumhaz) km_surv <- matrix(rep(km, nrow(X_test)), nrow = nrow(X_test), byrow = TRUE) @@ -178,17 +180,18 @@ for (tune_l1_ratio in c(FALSE)) { ) if (tune_l1_ratio) { data.frame(result$surv, check.names = FALSE) %>% readr::write_csv( - here::here( - "results", "non_kd", "breslow", cancer, paste0("survival_function_tuned_l1_ratio_", score, "_", cv_score, "_", split, ".csv") + paste( + "results", "non_kd", "breslow", cancer, paste0("survival_function_tuned_l1_ratio_", score, "_", cv_score, "_", split, ".csv"), + sep = "/" ) - ) - } - else { + ) + } else { data.frame(result$surv, check.names = FALSE) %>% readr::write_csv( - here::here( - "results", "non_kd", "breslow", cancer, paste0("survival_function_", score, "_", cv_score, "_", split, ".csv") + paste( + "results", "non_kd", "breslow", cancer, paste0("survival_function_", score, "_", cv_score, "_", split, ".csv"), + sep = "/" ) - ) + ) } n_failures <- n_failures + result$failures @@ -201,50 +204,52 @@ for (tune_l1_ratio in c(FALSE)) { colnames(lp_df) <- 1:ncol(lp_df) if 
(tune_l1_ratio) { lp_df %>% - write.csv( - here::here( - "results", "non_kd", "breslow", cancer, paste0("eta_tuned_l1_ratio_", score, "_", cv_score, ".csv") - ) + write.csv( + paste( + "results", "non_kd", "breslow", cancer, paste0("eta_tuned_l1_ratio_", score, "_", cv_score, ".csv"), + sep = "/" ) - } - else { + ) + } else { lp_df %>% - write.csv( - here::here( - "results", "non_kd", "breslow", cancer, paste0("eta_", score, "_", cv_score, ".csv") - ) + write.csv( + paste( + "results", "non_kd", "breslow", cancer, paste0("eta_", score, "_", cv_score, ".csv"), + sep = "/" ) + ) } - - } if (tune_l1_ratio) { - data.frame(failures) %>% write_csv( - here::here( - "results", "non_kd", "breslow", paste0("failures_tuned_l1_ratio_", score, "_", cv_score, ".csv") - ) + data.frame(failures) %>% write_csv( + paste( + "results", "non_kd", "breslow", paste0("failures_tuned_l1_ratio_", score, "_", cv_score, ".csv"), + sep = "/" ) + ) - data.frame(sparsity) %>% write_csv( - here::here( - "results", "non_kd", "breslow", paste0("sparsity_tuned_l1_ratio_", score, "_", cv_score, ".csv") - ) + data.frame(sparsity) %>% write_csv( + paste( + "results", "non_kd", "breslow", paste0("sparsity_tuned_l1_ratio_", score, "_", cv_score, ".csv"), + sep = "/" ) - } - else { - data.frame(failures) %>% write_csv( - here::here( - "results", "non_kd", "breslow", paste0("failures_", score, "_", cv_score, ".csv") - ) + ) + } else { + data.frame(failures) %>% write_csv( + paste( + "results", "non_kd", "breslow", paste0("failures_", score, "_", cv_score, ".csv"), + sep = "/" ) + ) - data.frame(sparsity) %>% write_csv( - here::here( - "results", "non_kd", "breslow", paste0("sparsity_", score, "_", cv_score, ".csv") - ) + data.frame(sparsity) %>% write_csv( + paste( + "results", "non_kd", "breslow", paste0("sparsity_", score, "_", cv_score, ".csv"), + sep = "/" ) - } + ) + } } } } diff --git a/workflow/scripts/r/run_glmnet_tuned.R b/workflow/scripts/r/run_glmnet_tuned.R new file mode 100644 index 0000000..860904c --- /dev/null +++ b/workflow/scripts/r/run_glmnet_tuned.R @@ -0,0 +1,257 @@ +log <- file(snakemake@log[[1]], open = "wt") +sink(log) + +suppressPackageStartupMessages({ + library(glmnet) + library(survival) + library(coefplot) + library(pec) + library(readr) + library(vroom) + library(dplyr) + library(splitTools) + library(glmnetUtils) + library(parallel) + library(doParallel) +}) + +glmnet::glmnet.control( + fdev = 0, + devmax = 1.0 +) + +config <- rjson::fromJSON( + file = snakemake@params[["config_path"]] +) + +# Get alpha. +get_alpha <- function(fit) { + alpha <- fit$alpha + error <- sapply(fit$modlist, function(mod) { + min(mod$cvm) + }) + alpha[which.min(error)] +} + +set.seed(config$seed) + +# https://stackoverflow.com/questions/7196450/create-a-dataframe-of-unequal-lengths +na.pad <- function(x, len) { + x[1:len] +} + +makePaddedDataFrame <- function(l, ...) { + maxlen <- max(sapply(l, length)) + data.frame(lapply(l, na.pad, len = maxlen), ...) 
+}
+
+for (tune_l1_ratio in c(TRUE)) {
+  for (score in c("vvh")) {
+    for (cv_score in c("lambda.min")) {
+      n_failures <- 0
+      failures <- list()
+      sparsity <- list()
+
+      for (cancer in snakemake@params[["cancer"]]) {
+        result_sparsity <- c()
+        lp_df <- list()
+        data <- data.frame(vroom::vroom(
+          paste(
+            "results", "preprocess_data",
+            paste0(cancer, ".csv"),
+            sep = "/"
+          )
+        )[, -1], check.names = FALSE)
+        train_splits <- data.frame(vroom::vroom(
+          paste(
+            "results", "make_splits",
+            paste0(cancer, "_train_splits.csv"),
+            sep = "/"
+          )
+        ), check.names = FALSE)
+        test_splits <- data.frame(vroom::vroom(
+          paste(
+            "results", "make_splits",
+            paste0(cancer, "_test_splits.csv"),
+            sep = "/"
+          )
+        ), check.names = FALSE)
+
+
+        for (split in 1:125) {
+          train_ix <- as.numeric(unname(train_splits[split, ]))
+          train_ix <- train_ix[!is.na(train_ix)] + 1
+
+          test_ix <- as.numeric(unname(test_splits[split, ]))
+          test_ix <- test_ix[!is.na(test_ix)] + 1
+
+          X_train <- data[train_ix, -(1:2)]
+          X_test <- data[test_ix, -(1:2)]
+          y_train <- Surv(data$OS_days[train_ix], data$OS[train_ix])
+          y_test <- Surv(data$OS_days[test_ix], data$OS[test_ix])
+          fold_ids <- rep(0, length(y_train))
+
+          fold_helper <- create_folds(
+            y = data$OS[train_ix],
+            k = config$n_inner_cv,
+            type = c("stratified"),
+            invert = TRUE,
+            seed = config$seed
+          )
+          for (i in 1:length(fold_helper)) {
+            fold_ids[fold_helper[[i]]] <- i
+          }
+          result <- tryCatch(
+            {
+              if (tune_l1_ratio) {
+                fit <- cva.glmnet(
+                  x = as.matrix(X_train),
+                  y = y_train,
+                  family = "cox",
+                  alpha = config$l1_ratio_tuned,
+                  lambda.min.ratio = config$eps,
+                  standardize = TRUE,
+                  nlambda = config$n_alphas,
+                  nfolds = config$n_inner_cv,
+                  grouped = score == "vvh",
+                  foldid = fold_ids # ,
+                  # outerParallel = {inner_cl <- parallel::makeForkCluster(7); parallel::clusterSetRNGStream(inner_cl, config$seed); inner_cl}
+                )
+
+                fit <- fit$modlist[[which((fit$alpha == get_alpha(fit)))]]
+              } else {
+                fit <- cv.glmnet(
+                  x = as.matrix(X_train),
+                  y = y_train,
+                  family = "cox",
+                  alpha = config$l1_ratio,
+                  lambda.min.ratio = config$eps,
+                  standardize = TRUE,
+                  nlambda = config$n_alphas,
+                  nfolds = config$n_inner_cv,
+                  grouped = score == "vvh",
+                  foldid = fold_ids
+                )
+              }
+
+              if (cv_score %in% c("lambda.min", "lambda.1se")) {
+                n_sparsity <- nrow(extract.coef(fit, cv_score))
+                if (n_sparsity == 0) {
+                  stop()
+                }
+                linear_predictor <- as.vector(predict(fit, as.matrix(X_test), s = fit$lambda.min))
+                X_train_survival <- cbind(data$OS_days[train_ix], data$OS[train_ix], X_train[, sapply(rownames(extract.coef(fit, cv_score)), function(x) grep(x, colnames(X_test), fixed = TRUE)), drop = FALSE])
+                X_test_survival <- X_test[, sapply(rownames(extract.coef(fit, cv_score)), function(x) grep(x, colnames(X_test), fixed = TRUE)), drop = FALSE]
+              } else {
+                if (which(fit$lambda.min == fit$lambda) == 1) {
+                  stop()
+                }
+                transformed_error_space <- fit$cvm - (((fit$lambda[which(fit$lambda == fit$lambda.min)] - fit$cvm[1]) / (fit$nzero[which(fit$lambda == fit$lambda.min)])) * fit$nzero)
+                lambda_ix <- which.min(transformed_error_space[1:which(fit$lambda == fit$lambda.min)])
+                coefs <- fit$glmnet.fit$beta[, lambda_ix]
+                coefs <- coefs[coefs != 0.0]
+                n_sparsity <- length(coefs)
+                if (n_sparsity == 0) {
+                  stop()
+                }
+
+                linear_predictor <- as.vector(as.matrix(X_test) %*% as.matrix(fit$glmnet.fit$beta[, lambda_ix]))
+                X_train_survival <- cbind(data$OS_days[train_ix], data$OS[train_ix], X_train[, sapply(names(coefs), function(x) grep(x, colnames(X_test), fixed = TRUE)), drop = FALSE])
+                X_test_survival <- X_test[, sapply(names(coefs), function(x) grep(x, colnames(X_test), fixed = TRUE)), drop = FALSE]
+              }
+
+              colnames(X_train_survival)[1:2] <- c("time", "event")
+
+              cox_helper <- coxph(Surv(time, event) ~ ., data = X_train_survival, ties = "breslow", init = extract.coef(fit, cv_score)[, 1], iter.max = 0, x = TRUE)
+              surv <- pec::predictSurvProb(cox_helper, X_test_survival, unique(sort(y_test[, 1])))
+              if (length(which(is.na(surv[1, ]))) > 1) {
+                surv[, which(is.na(surv[1, ]))] <- matrix(rep(surv[, max(which(!is.na(surv[1, ])))], length(which(is.na(surv[1, ])))), ncol = length(which(is.na(surv[1, ]))))
+              } else {
+                surv[, which(is.na(surv[1, ]))] <- surv[, max(which(!is.na(surv[1, ])))]
+              }
+              colnames(surv) <- unique(sort(y_test[, 1]))
+              list(sparsity = n_sparsity, linear_predictor = linear_predictor, surv = surv, failures = 0)
+            },
+            error = function(cond) {
+              times <- sort(unique(y_test[, 1]))
+              km <- exp(-survfit(y_test ~ 1)$cumhaz)
+              km_surv <- matrix(rep(km, nrow(X_test)), nrow = nrow(X_test), byrow = TRUE)
+              colnames(km_surv) <- times
+              return(list(sparsity = 0, linear_predictor = rep(0, nrow(X_test)), surv = km_surv, failures = 1))
+            }
+          )
+          if (tune_l1_ratio) {
+            data.frame(result$surv, check.names = FALSE) %>% readr::write_csv(
+              paste(
+                "results", "non_kd", "breslow", cancer, paste0("survival_function_tuned_l1_ratio_", score, "_", cv_score, "_", split, ".csv"),
+                sep = "/"
+              )
+            )
+          } else {
+            data.frame(result$surv, check.names = FALSE) %>% readr::write_csv(
+              paste(
+                "results", "non_kd", "breslow", cancer, paste0("survival_function_", score, "_", cv_score, "_", split, ".csv"),
+                sep = "/"
+              )
+            )
+          }
+
+          n_failures <- n_failures + result$failures
+          result_sparsity <- c(result_sparsity, result$sparsity)
+          lp_df[[split]] <- result$linear_predictor
+        }
+        failures[[cancer]] <- n_failures
+        sparsity[[cancer]] <- result_sparsity
+        lp_df <- makePaddedDataFrame(lp_df)
+        colnames(lp_df) <- 1:ncol(lp_df)
+        if (tune_l1_ratio) {
+          lp_df %>%
+            write.csv(
+              paste(
+                "results", "non_kd", "breslow", cancer, paste0("eta_tuned_l1_ratio_", score, "_", cv_score, ".csv"),
+                sep = "/"
+              )
+            )
+        } else {
+          lp_df %>%
+            write.csv(
+              paste(
+                "results", "non_kd", "breslow", cancer, paste0("eta_", score, "_", cv_score, ".csv"),
+                sep = "/"
+              )
+            )
+        }
+      }
+
+      if (tune_l1_ratio) {
+        data.frame(failures) %>% write_csv(
+          paste(
+            "results", "non_kd", "breslow", cancer, paste0("failures_tuned_l1_ratio_", score, "_", cv_score, ".csv"),
+            sep = "/"
+          )
+        )
+
+        data.frame(sparsity) %>% write_csv(
+          paste(
+            "results", "non_kd", "breslow", cancer, paste0("sparsity_tuned_l1_ratio_", score, "_", cv_score, ".csv"),
+            sep = "/"
+          )
+        )
+      } else {
+        data.frame(failures) %>% write_csv(
+          paste(
+            "results", "non_kd", "breslow", paste0("failures_", score, "_", cv_score, ".csv"),
+            sep = "/"
+          )
+        )
+
+        data.frame(sparsity) %>% write_csv(
+          paste(
+            "results", "non_kd", "breslow", paste0("sparsity_", score, "_", cv_score, ".csv"),
+            sep = "/"
+          )
+        )
+      }
+    }
+  }
+}
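A note on the tuning step in the script above: when `tune_l1_ratio` is `TRUE`, `glmnetUtils::cva.glmnet` fits one `cv.glmnet` per candidate alpha, and the script then indexes into `modlist` with a `get_alpha` helper that is defined earlier in the file, outside this hunk. Below is a minimal, self-contained sketch of that pattern on toy data; it is not part of the patch, and `best_alpha` is a hypothetical stand-in for `get_alpha`, assumed to pick the alpha with the lowest cross-validated deviance.

```r
# Illustrative sketch (not part of the patch): tuning the elastic-net mixing
# parameter alpha for a Cox model with glmnetUtils::cva.glmnet, reusing
# pre-assigned fold ids as the script above does.
library(glmnet)
library(glmnetUtils)
library(survival)
library(splitTools)

set.seed(42)
x <- matrix(rnorm(100 * 20), nrow = 100)               # toy design matrix
y <- Surv(rexp(100, rate = 0.1), rbinom(100, 1, 0.7))  # toy survival response

# Stratified folds on the event indicator; invert = TRUE returns, per fold,
# the indices it holds out, which we convert into a fold-id vector.
folds <- create_folds(y[, 2], k = 5, type = "stratified", invert = TRUE, seed = 42)
fold_ids <- integer(nrow(x))
for (i in seq_along(folds)) fold_ids[folds[[i]]] <- i

fit <- cva.glmnet(
  x = x, y = y, family = "cox",
  alpha = c(0.1, 0.5, 0.9, 1),
  foldid = fold_ids, grouped = TRUE
)

# Hypothetical stand-in for the script's get_alpha helper: pick the alpha
# whose cv.glmnet run attains the lowest cross-validated deviance.
best_alpha <- function(cva_fit) {
  cva_fit$alpha[which.min(sapply(cva_fit$modlist, function(m) min(m$cvm)))]
}
best_fit <- fit$modlist[[which(fit$alpha == best_alpha(fit))]]
```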
diff --git a/paper/scripts/r/run_path_glmnet.R b/workflow/scripts/r/run_path_glmnet.R
similarity index 51%
rename from paper/scripts/r/run_path_glmnet.R
rename to workflow/scripts/r/run_path_glmnet.R
index d0e6e5b..49d90df 100644
--- a/paper/scripts/r/run_path_glmnet.R
+++ b/workflow/scripts/r/run_path_glmnet.R
@@ -1,28 +1,26 @@
-library(rjson)
-library(glmnet)
-library(survival)
-library(coefplot)
-library(pec)
-library(readr)
-library(vroom)
-library(dplyr)
-library(splitTools)
-
+log <- file(snakemake@log[[1]], open = "wt")
+sink(log)
+
+suppressPackageStartupMessages({
+  library(glmnet)
+  library(survival)
+  library(coefplot)
+  library(pec)
+  library(readr)
+  library(vroom)
+  library(dplyr)
+  library(splitTools)
+})
 glmnet::glmnet.control(
-    fdev = 0,
-    devmax = 1.0
+  fdev = 0,
+  devmax = 1.0
 )
 config <- rjson::fromJSON(
-    file = here::here(
-        "config.json"
-    )
+  file = snakemake@params[["config_path"]]
 )
-set.seed(config$seed)
-
-
 # https://stackoverflow.com/questions/7196450/create-a-dataframe-of-unequal-lengths
 na.pad <- function(x, len) {
   x[1:len]
 }
@@ -41,22 +39,13 @@ for (cancer in config$datasets) {
   n_failures <- 0
   result_sparsity <- list()
   data <- data.frame(vroom::vroom(
-    here::here(
-      "data", "processed", "TCGA",
-      paste0(cancer, "_data_preprocessed.csv")
-    )
+    paste0("results/preprocess_data/", cancer, ".csv")
   )[, -1], check.names = FALSE)
   train_splits <- data.frame(vroom::vroom(
-    here::here(
-      "data", "splits", "TCGA",
-      paste0(cancer, "_train_splits.csv")
-    )
+    paste0("results/make_splits/", cancer, "_train_splits.csv")
  ), check.names = FALSE)
   test_splits <- data.frame(vroom::vroom(
-    here::here(
-      "data", "splits", "TCGA",
-      paste0(cancer, "_test_splits.csv")
-    )
+    paste0("results/make_splits/", cancer, "_test_splits.csv")
  ), check.names = FALSE)
 
 
@@ -97,50 +86,50 @@ for (cancer in config$datasets) {
       km_surv <- matrix(rep(km, nrow(X_test)), nrow = nrow(X_test), byrow = TRUE)
       colnames(km_surv) <- times
       data.frame(km_surv, check.names = FALSE) %>% readr::write_csv(
-        here::here(
-          "results", "non_kd", "breslow", cancer, "path", paste0("survival_function_", z, "_alpha_", split, ".csv")
-        )
-      )
-    }
-    else {
-      current_coef <- path_coefs[, z]
-      current_coef <- current_coef[current_coef != 0.0]
-      if (length(current_coef) == 0) {
-        result_sparsity[[split]] <- c(result_sparsity[[split]], 0)
-        times <- sort(unique(y_test[, 1]))
-        km <- exp(-survfit(y_test ~ 1)$cumhaz)
-        km_surv <- matrix(rep(km, nrow(X_test)), nrow = nrow(X_test), byrow = TRUE)
-        colnames(km_surv) <- times
-        data.frame(km_surv, check.names = FALSE) %>% readr::write_csv(
-          here::here(
-            "results", "non_kd", "breslow", cancer, "path", paste0("survival_function_", z, "_alpha_", split, ".csv")
+        paste(
+          "results", "non_kd", "breslow", cancer, "path", paste0("survival_function_", z, "_alpha_", split, ".csv"),
+          sep = "/"
         )
       )
     } else {
-      X_train_survival <- cbind(data$OS_days[train_ix], data$OS[train_ix], X_train[, sapply(names(current_coef), function(x) grep(x, colnames(X_test), fixed = TRUE)), drop = FALSE])
-      X_test_survival <- X_test[, sapply(names(current_coef), function(x) grep(x, colnames(X_test), fixed = TRUE)), drop = FALSE]
-      colnames(X_train_survival)[1:2] <- c("time", "event")
-      cox_helper <- coxph(Surv(time, event) ~ ., data = X_train_survival, ties = "breslow", init = current_coef, iter.max = 0, x = TRUE)
-      surv <- pec::predictSurvProb(cox_helper, X_test_survival, unique(sort(y_test[, 1])))
-      if (length(which(is.na(surv[1, ]))) > 1) {
-        surv[, which(is.na(surv[1, ]))] <- matrix(rep(surv[, max(which(!is.na(surv[1, ])))], length(which(is.na(surv[1, ])))), ncol = length(which(is.na(surv[1, ]))))
+      current_coef <- path_coefs[, z]
+      current_coef <- current_coef[current_coef != 0.0]
+      if (length(current_coef) == 0) {
+        result_sparsity[[split]] <- c(result_sparsity[[split]], 0)
+        times <- sort(unique(y_test[, 1]))
+        km <- exp(-survfit(y_test ~ 1)$cumhaz)
+        km_surv <- matrix(rep(km, nrow(X_test)), nrow = nrow(X_test), byrow = TRUE)
+        colnames(km_surv) <- times
+        data.frame(km_surv, check.names = FALSE) %>% readr::write_csv(
+          paste(
+            "results", "non_kd", "breslow", cancer, "path", paste0("survival_function_", z, "_alpha_", split, ".csv"),
+            sep = "/"
+          )
+        )
      } else {
-        surv[, which(is.na(surv[1, ]))] <- surv[, max(which(!is.na(surv[1, ])))]
-      }
-      colnames(surv) <- unique(sort(y_test[, 1]))
-      result_sparsity[[split]] <- c(result_sparsity[[split]], length(current_coef))
-      data.frame(surv, check.names = FALSE) %>% readr::write_csv(
-        here::here(
-          "results", "non_kd", "breslow", cancer, "path", paste0("survival_function_", z, "_alpha_", split, ".csv")
+        X_train_survival <- cbind(data$OS_days[train_ix], data$OS[train_ix], X_train[, sapply(names(current_coef), function(x) grep(x, colnames(X_test), fixed = TRUE)), drop = FALSE])
+        X_test_survival <- X_test[, sapply(names(current_coef), function(x) grep(x, colnames(X_test), fixed = TRUE)), drop = FALSE]
+        colnames(X_train_survival)[1:2] <- c("time", "event")
+        cox_helper <- coxph(Surv(time, event) ~ ., data = X_train_survival, ties = "breslow", init = current_coef, iter.max = 0, x = TRUE)
+        surv <- pec::predictSurvProb(cox_helper, X_test_survival, unique(sort(y_test[, 1])))
+        if (length(which(is.na(surv[1, ]))) > 1) {
+          surv[, which(is.na(surv[1, ]))] <- matrix(rep(surv[, max(which(!is.na(surv[1, ])))], length(which(is.na(surv[1, ])))), ncol = length(which(is.na(surv[1, ]))))
+        } else {
+          surv[, which(is.na(surv[1, ]))] <- surv[, max(which(!is.na(surv[1, ])))]
+        }
+        colnames(surv) <- unique(sort(y_test[, 1]))
+        result_sparsity[[split]] <- c(result_sparsity[[split]], length(current_coef))
+        data.frame(surv, check.names = FALSE) %>% readr::write_csv(
+          paste(
+            "results", "non_kd", "breslow", cancer, "path", paste0("survival_function_", z, "_alpha_", split, ".csv"),
+            sep = "/"
+          )
        )
-      )
-    }
+      }
     }
-  }
   },
   error = function(cond) {
-    print(cond)
     n_failures <- n_failures + 1
     result_sparsity[[split]] <- vector()
     times <- sort(unique(y_test[, 1]))
@@ -150,8 +139,9 @@ for (cancer in config$datasets) {
     for (z in 1:100) {
       result_sparsity[[split]] <- c(result_sparsity[[split]], 0)
       data.frame(km_surv, check.names = FALSE) %>% readr::write_csv(
-        here::here(
-          "results", "non_kd", "breslow", cancer, "path", paste0("survival_function_", z, "_alpha_", split, ".csv")
+        paste(
+          "results", "non_kd", "breslow", cancer, "path", paste0("survival_function_", z, "_alpha_", split, ".csv"),
+          sep = "/"
        )
      )
    }
@@ -159,15 +149,17 @@
   )
 }
 data.frame(result_sparsity) %>% write_csv(
-  here::here(
-    "results", "non_kd", "breslow", cancer, "path", "sparsity.csv"
+  paste(
+    "results", "non_kd", "breslow", cancer, "path", "sparsity.csv",
+    sep = "/"
   )
 )
 failures[[cancer]] <- n_failures
 }
 data.frame(failures) %>% write_csv(
-  here::here(
-    "results", "non_kd", "breslow", cancer, "path", "failures.csv"
+  paste(
+    "results", "non_kd", "breslow", cancer, "path", "failures.csv",
+    sep = "/"
  )
 )
diff --git a/workflow/scripts/r/time_glmnet.R b/workflow/scripts/r/time_glmnet.R
new file mode 100644
index 0000000..bf3130c
--- /dev/null
+++ b/workflow/scripts/r/time_glmnet.R
@@ -0,0 +1,98 @@
+log <- file(snakemake@log[[1]], open = "wt")
+sink(log)
+
+suppressPackageStartupMessages({
+  library(glmnet)
+  library(survival)
+  library(readr)
+  library(vroom)
+  library(dplyr)
+  library(microbenchmark)
+  library(splitTools)
+  library(glmnetUtils)
+})
+
+# Prevent early stopping.
+glmnet::glmnet.control(
+  fdev = 0,
+  devmax = 1.0
+)
+
+config <- rjson::fromJSON(
+  file = snakemake@params[["config_path"]]
+)
+
+set.seed(config$seed)
+
+# https://stackoverflow.com/questions/7196450/create-a-dataframe-of-unequal-lengths
+na.pad <- function(x, len) {
+  x[1:len]
+}
+
+makePaddedDataFrame <- function(l, ...) {
+  maxlen <- max(sapply(l, length))
+  data.frame(lapply(l, na.pad, len = maxlen), ...)
+}
+
+timing <- list()
+
+for (tune_l1_ratio in c(TRUE, FALSE)) {
+  for (cancer in c(config$datasets)) {
+    timing[[cancer]] <- c()
+    data <- data.frame(vroom::vroom(
+      paste0("results/preprocess_data/", cancer, ".csv")
+    )[, -1], check.names = FALSE)
+    x <- as.matrix(data[, -(1:2)])
+    y <- Surv(data$OS_days, data$OS)
+    fold_ids <- rep(0, length(y))
+
+    fold_helper <- create_folds(
+      y = data$OS,
+      k = config$n_inner_cv,
+      type = c("stratified"),
+      invert = TRUE,
+      seed = config$seed
+    )
+    for (i in 1:length(fold_helper)) {
+      fold_ids[fold_helper[[i]]] <- i
+    }
+    if (tune_l1_ratio) {
+      tim <- microbenchmark(
+        cva.glmnet(
+          x = x,
+          y = y,
+          family = "cox",
+          alpha = config$l1_ratio_tuned,
+          lambda.min.ratio = config$eps,
+          standardize = TRUE,
+          nlambda = config$n_alphas,
+          foldid = fold_ids,
+          grouped = TRUE
+        ),
+        times = config$timing_reps
+      )
+    } else {
+      tim <- microbenchmark(
+        cv.glmnet(
+          x = x,
+          y = y,
+          family = "cox",
+          alpha = config$l1_ratio,
+          lambda.min.ratio = config$eps,
+          standardize = TRUE,
+          nlambda = config$n_alphas,
+          foldid = fold_ids,
+          grouped = TRUE
+        ),
+        times = config$timing_reps
+      )
+    }
+
+    timing[[cancer]] <- tim$time * 1e-9
+  }
+  if (tune_l1_ratio) {
+    data.frame(timing) %>% write_csv(paste("results", "non_kd", "breslow", "timing_tuned_l1_ratio.csv", sep = "/"))
+  } else {
+    data.frame(timing) %>% write_csv(paste("results", "non_kd", "breslow", "timing.csv", sep = "/"))
+  }
+}
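For reference, `microbenchmark` reports times in nanoseconds, which is why `time_glmnet.R` multiplies `tim$time` by `1e-9` before writing seconds to CSV. A stripped-down sketch of one timing run on toy data follows; it is not part of the patch and all names are illustrative.

```r
# Illustrative sketch (not part of the patch): timing a cross-validated Cox
# fit with microbenchmark, mirroring the loop body of time_glmnet.R.
library(glmnet)
library(survival)
library(microbenchmark)

set.seed(42)
x <- matrix(rnorm(150 * 50), nrow = 150)
y <- Surv(rexp(150, rate = 0.1), rbinom(150, 1, 0.7))

tim <- microbenchmark(
  cv.glmnet(x, y, family = "cox", alpha = 1, nfolds = 5, grouped = TRUE),
  times = 5  # config$timing_reps in the actual script
)
seconds <- tim$time * 1e-9  # microbenchmark reports nanoseconds
summary(seconds)
```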