From 512b2538eeb448555830213cde61c689919000a3 Mon Sep 17 00:00:00 2001 From: ZekeMarshall Date: Wed, 16 Oct 2024 14:14:54 +0100 Subject: [PATCH] Reading model display from duckdb --- inst/app/app.R | 6 +- .../data_entry/surveyDataValidator_ui.R | 5 ++ .../niche_models/nmModelDisplay_server.R | 65 +++++++++++++++---- .../modules/niche_models/nmModelRun_server.R | 7 -- .../modules/niche_models/nmSidebar_server.R | 2 + inst/app/modules/niche_models/nmSidebar_ui.R | 44 ++++++++++++- 6 files changed, 106 insertions(+), 23 deletions(-) diff --git a/inst/app/app.R b/inst/app/app.R index 0381b92..5458393 100644 --- a/inst/app/app.R +++ b/inst/app/app.R @@ -51,6 +51,9 @@ library(mlr3pipelines) library(mlr3learners) library(mlr3extralearners) library(targets) +library(DBI) +library(dbplyr) +library(duckdb) library(qs) library(stats) library(DALEX) @@ -59,11 +62,12 @@ library(DALEXtra) source("./../../R/temp_functions.R", local = TRUE) source("./../../R/graph_functions.R", local = TRUE) tar_store <- file.path("C:/Users/zekmar/Github/GBIENMAnalysis/_targets") +db_path <- file.path("C:", "Users", "zekmar", "OneDrive - UKCEH", "GBIENMWorkingDir", "OutputData") + modelled_species <- targets::tar_read(name = "Species", store = tar_store) mlr3extralearners::install_learners(c("classif.gam", "classif.randomForest")) - # Render documentation ---------------------------------------------------- # rmarkdown::render(input = "./inst/app/docs/documentation.Rmd", output_dir = "./inst/app/www") diff --git a/inst/app/modules/data_entry/surveyDataValidator_ui.R b/inst/app/modules/data_entry/surveyDataValidator_ui.R index 40918cb..265cbea 100644 --- a/inst/app/modules/data_entry/surveyDataValidator_ui.R +++ b/inst/app/modules/data_entry/surveyDataValidator_ui.R @@ -27,6 +27,11 @@ surveyDataValidatorUI <- function(id){ shiny::div( shiny::actionButton(inputId = ns("reallocateGroups"), label = "Re-allocate Groups") + ), + + shiny::div( + shiny::actionButton(inputId = ns("trimWS"), + label = "Trim White Space") ) ), diff --git a/inst/app/modules/niche_models/nmModelDisplay_server.R b/inst/app/modules/niche_models/nmModelDisplay_server.R index 7f43a82..1f89eec 100644 --- a/inst/app/modules/niche_models/nmModelDisplay_server.R +++ b/inst/app/modules/niche_models/nmModelDisplay_server.R @@ -6,12 +6,14 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) { focalSpecies <- reactiveVal() selectedModelDisplay <- reactiveVal() selectedVariablesDisplay <- reactiveVal() + selectedMarginalEffectsPlot <- reactiveVal() observe({ focalSpecies(sidebar_nm_options()$focalSpecies) selectedModelDisplay(sidebar_nm_options()$selectedModelDisplay) selectedVariablesDisplay(sidebar_nm_options()$selectedVariablesDisplay) + selectedMarginalEffectsPlot(sidebar_nm_options()$selectedMarginalEffectsPlot) }) |> bindEvent(sidebar_nm_options(), @@ -20,7 +22,7 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) { # Retrieve Data ----------------------------------------------------------- measures_rval <- reactiveVal() - aleData_rval <- reactiveVal() + meData_rval <- reactiveVal() featureImportance_rval <- reactiveVal() observe({ @@ -34,26 +36,62 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) { focalSpecies <- focalSpecies() selectedModelDisplay <- selectedModelDisplay() + selectedMarginalEffectsPlot <- selectedMarginalEffectsPlot() selectedVariablesDisplay <- c("_full_model_", "_baseline_", selectedVariablesDisplay()) - measures <- targets::tar_read(name = "AllMeasures", store = tar_store) |> - dplyr::filter(species == focalSpecies) |> - dplyr::filter(model %in% selectedModelDisplay) + # Open connection + con <- DBI::dbConnect(duckdb::duckdb(), + dbdir = file.path(db_path, "biens-db.duckdb"), + read_only = TRUE) - aleData <- targets::tar_read(name = "AllALEData", store = tar_store) |> + # Retrieve measures + measures <- dplyr::tbl(src = con, "AllMeasures") |> dplyr::filter(species == focalSpecies) |> dplyr::filter(model %in% selectedModelDisplay) |> - dplyr::filter(variable %in% selectedVariablesDisplay) + dplyr::collect() - featureImportance <- targets::tar_read(name = "AllFeatureImportance", store = tar_store) |> + # Retrieve marginal effects + if(selectedMarginalEffectsPlot == "ALE"){ + + meData <- dplyr::tbl(src = con, "AllALEData") |> + dplyr::filter(species == focalSpecies) |> + dplyr::filter(model %in% selectedModelDisplay) |> + dplyr::filter(variable %in% selectedVariablesDisplay) |> + dplyr::collect() + + } else if(selectedMarginalEffectsPlot == "PDP"){ + + meData <- dplyr::tbl(src = con, "AllPDPData") |> + dplyr::filter(species == focalSpecies) |> + dplyr::filter(model %in% selectedModelDisplay) |> + dplyr::filter(variable %in% selectedVariablesDisplay) |> + dplyr::collect() + + } else if(selectedMarginalEffectsPlot == "CP"){ + + meData <- dplyr::tbl(src = con, "AllCDData") |> + dplyr::filter(species == focalSpecies) |> + dplyr::filter(model %in% selectedModelDisplay) |> + dplyr::filter(variable %in% selectedVariablesDisplay) |> + dplyr::collect() + + } + + # Retrieve feature importance + featureImportance <- dplyr::tbl(src = con, "AllFeatureImportance") |> dplyr::filter(species == focalSpecies) |> dplyr::filter(model %in% selectedModelDisplay) |> - dplyr::filter(variable %in% selectedVariablesDisplay) + dplyr::filter(variable %in% selectedVariablesDisplay) |> + dplyr::collect() + + # Close connection + DBI::dbDisconnect(conn = con) - assign(x = "featureImportance", value = featureImportance, envir = .GlobalEnv) + # assign(x = "aleData", value = aleData, envir = .GlobalEnv) + # Store data in reactive objects measures_rval(measures) - aleData_rval(aleData) + meData_rval(meData) featureImportance_rval(featureImportance) # Stop busy spinner @@ -63,6 +101,7 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) { bindEvent(focalSpecies(), selectedModelDisplay(), selectedVariablesDisplay(), + selectedMarginalEffectsPlot(), ignoreInit = TRUE) @@ -214,11 +253,11 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) { text = "Rendering ALE Plot" ) - aleData <- aleData_rval() + meData <- meData_rval() output$ale_plot <- plotly::renderPlotly({ - ale_plot <- create_ale_plot(ale_data = aleData) + ale_plot <- create_ale_plot(ale_data = meData) ale_plotly <- plotly::ggplotly(ale_plot) @@ -231,7 +270,7 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) { }) |> - bindEvent(aleData_rval(), + bindEvent(meData_rval(), ignoreInit = FALSE) diff --git a/inst/app/modules/niche_models/nmModelRun_server.R b/inst/app/modules/niche_models/nmModelRun_server.R index 58ea51c..cbbd0e7 100644 --- a/inst/app/modules/niche_models/nmModelRun_server.R +++ b/inst/app/modules/niche_models/nmModelRun_server.R @@ -81,11 +81,6 @@ nmModelRun <- function(input, output, session, sidebar_nm_options, nmDataInput) identifyPredDrivers <- identifyPredDrivers() }) - # assign(x = "predictorData", value = predictorData, envir = .GlobalEnv) - # assign(x = "selectedModel", value = selectedModel, envir = .GlobalEnv) - # assign(x = "selectedExplainer", value = selectedExplainer, envir = .GlobalEnv) - # assign(x = "identifyPredDrivers", value = identifyPredDrivers, envir = .GlobalEnv) - results <- predict(selectedModel, newdata = predictorData, predict_type = "prob") |> tibble::as_tibble() |> dplyr::bind_cols(dplyr::select(predictors, id)) |> @@ -93,8 +88,6 @@ nmModelRun <- function(input, output, session, sidebar_nm_options, nmDataInput) "Prob.Presence" = "Present", "Prob.Absence" = "Absent") - print(results) - # Identify model prediction drivers if(isTRUE(identifyPredDrivers)){ diff --git a/inst/app/modules/niche_models/nmSidebar_server.R b/inst/app/modules/niche_models/nmSidebar_server.R index 8682032..9eaacb2 100644 --- a/inst/app/modules/niche_models/nmSidebar_server.R +++ b/inst/app/modules/niche_models/nmSidebar_server.R @@ -12,6 +12,7 @@ nmSidebar <- function(input, output, session) { "focalSpecies" = input$focalSpecies, "selectedModelDisplay" = input$selectedModelDisplay, "selectedVariablesDisplay" = input$selectedVariablesDisplay, + "selectedMarginalEffectsPlot" = input$selectedMarginalEffectsPlot, "identifyPredDrivers" = input$identifyPredDrivers, "selectedModelPredict" = input$selectedModelPredict @@ -24,6 +25,7 @@ nmSidebar <- function(input, output, session) { input$focalSpecies, input$selectedModelDisplay, input$selectedVariablesDisplay, + input$selectedMarginalEffectsPlot, input$identifyPredDrivers, input$selectedModelPredict, ignoreInit = TRUE) diff --git a/inst/app/modules/niche_models/nmSidebar_ui.R b/inst/app/modules/niche_models/nmSidebar_ui.R index c8632f7..218dc1a 100644 --- a/inst/app/modules/niche_models/nmSidebar_ui.R +++ b/inst/app/modules/niche_models/nmSidebar_ui.R @@ -75,7 +75,7 @@ nmSidebarUI <- function(id){ col_widths = c(11, 1), shiny::selectizeInput(inputId = ns("selectedModelDisplay"), - label = "Selected Model", + label = "Model", choices = c("GAM", "NNet", "GLM", "RF", "MARS", "SVM", "XGB", "WE"), selected = c("GAM", "NNet", "GLM", "RF", "MARS", "SVM", "XGB", "WE"), multiple = TRUE), @@ -107,7 +107,7 @@ nmSidebarUI <- function(id){ col_widths = c(11, 1), shiny::selectizeInput(inputId = ns("selectedVariablesDisplay"), - label = "Selected Variables", + label = "Variables", choices = c("F", "L", "N", "R", "S", "DG", "DS", "H"), selected = c("F", "L", "N", "R", "S", "DG", "DS", "H"), multiple = TRUE), @@ -128,6 +128,46 @@ nmSidebarUI <- function(id){ shiny::div(shiny::br()) + ), + + shiny::div( + + id = ns("selectedMarginalEffectsPlot_div"), + + bslib::layout_columns( + + col_widths = c(11, 1), + + shiny::selectizeInput(inputId = ns("selectedMarginalEffectsPlot"), + label = "Marginal Effects Plot", + choices = c("ALE", "PDP", "CP"), + selected = "ALE", + multiple = FALSE), + + bslib::popover( + bsicons::bs_icon("info-circle"), + title = "Marginal Effects Plot Selection", + id = ns("selectedMarginalEffectsPlot_info"), + shiny::markdown( + " + Select the marginal effects plot to view. Three options are available: + + 1) **Accumulated Local Effects Plot (ALE)**. Broadly, this can be interpreted as *the effect of a given variable on the total probability of the species occurence.* + + 2) **Partial Dependence Plot (PDP)**. Broadly, this can be interpreted as *...* + + 3) **Ceritus Paribus Plot (CP)**. Broadly, this can be interpreted as *the probability of occurrence for that variable value, with all other variables held constant.* + + Please see the documentation for more details. + " + ), + placement = "bottom" + ) + + ), + + shiny::div(shiny::br()) + ) ),