Skip to content

Commit

Permalink
Reading model display from duckdb
Browse files Browse the repository at this point in the history
  • Loading branch information
ZekeMarshall committed Oct 16, 2024
1 parent 15e81e6 commit 512b253
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 23 deletions.
6 changes: 5 additions & 1 deletion inst/app/app.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ library(mlr3pipelines)
library(mlr3learners)
library(mlr3extralearners)
library(targets)
library(DBI)
library(dbplyr)
library(duckdb)
library(qs)
library(stats)
library(DALEX)
Expand All @@ -59,11 +62,12 @@ library(DALEXtra)
# Load the app's helper functions; local = TRUE evaluates each file in the
# environment this script is running in rather than the global environment.
source("./../../R/temp_functions.R", local = TRUE)
source("./../../R/graph_functions.R", local = TRUE)
# NOTE(review): the two paths below are absolute and specific to one user's
# machine -- confirm how they should be configured before running elsewhere.
# Location of the targets data store passed to targets::tar_read().
tar_store <- file.path("C:/Users/zekmar/Github/GBIENMAnalysis/_targets")
# Directory holding model output data, including the DuckDB database file.
db_path <- file.path("C:", "Users", "zekmar", "OneDrive - UKCEH", "GBIENMWorkingDir", "OutputData")

# Read the "Species" target from the targets store.
modelled_species <- targets::tar_read(name = "Species", store = tar_store)

# Install the dependencies required by the GAM and random forest learners.
mlr3extralearners::install_learners(c("classif.gam", "classif.randomForest"))


# Render documentation ----------------------------------------------------
# rmarkdown::render(input = "./inst/app/docs/documentation.Rmd", output_dir = "./inst/app/www")

Expand Down
5 changes: 5 additions & 0 deletions inst/app/modules/data_entry/surveyDataValidator_ui.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ surveyDataValidatorUI <- function(id){
shiny::div(
shiny::actionButton(inputId = ns("reallocateGroups"),
label = "Re-allocate Groups")
),

shiny::div(
shiny::actionButton(inputId = ns("trimWS"),
label = "Trim White Space")
)

),
Expand Down
65 changes: 52 additions & 13 deletions inst/app/modules/niche_models/nmModelDisplay_server.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {
focalSpecies <- reactiveVal()
selectedModelDisplay <- reactiveVal()
selectedVariablesDisplay <- reactiveVal()
selectedMarginalEffectsPlot <- reactiveVal()

observe({

focalSpecies(sidebar_nm_options()$focalSpecies)
selectedModelDisplay(sidebar_nm_options()$selectedModelDisplay)
selectedVariablesDisplay(sidebar_nm_options()$selectedVariablesDisplay)
selectedMarginalEffectsPlot(sidebar_nm_options()$selectedMarginalEffectsPlot)

}) |>
bindEvent(sidebar_nm_options(),
Expand All @@ -20,7 +22,7 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {

# Retrieve Data -----------------------------------------------------------
measures_rval <- reactiveVal()
aleData_rval <- reactiveVal()
meData_rval <- reactiveVal()
featureImportance_rval <- reactiveVal()

observe({
Expand All @@ -34,26 +36,62 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {

focalSpecies <- focalSpecies()
selectedModelDisplay <- selectedModelDisplay()
selectedMarginalEffectsPlot <- selectedMarginalEffectsPlot()
selectedVariablesDisplay <- c("_full_model_", "_baseline_", selectedVariablesDisplay())

measures <- targets::tar_read(name = "AllMeasures", store = tar_store) |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay)
# Open connection
con <- DBI::dbConnect(duckdb::duckdb(),
dbdir = file.path(db_path, "biens-db.duckdb"),
read_only = TRUE)

aleData <- targets::tar_read(name = "AllALEData", store = tar_store) |>
# Retrieve measures
measures <- dplyr::tbl(src = con, "AllMeasures") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::filter(variable %in% selectedVariablesDisplay)
dplyr::collect()

featureImportance <- targets::tar_read(name = "AllFeatureImportance", store = tar_store) |>
# Retrieve marginal effects
if(selectedMarginalEffectsPlot == "ALE"){

meData <- dplyr::tbl(src = con, "AllALEData") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::filter(variable %in% selectedVariablesDisplay) |>
dplyr::collect()

} else if(selectedMarginalEffectsPlot == "PDP"){

meData <- dplyr::tbl(src = con, "AllPDPData") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::filter(variable %in% selectedVariablesDisplay) |>
dplyr::collect()

} else if(selectedMarginalEffectsPlot == "CP"){

meData <- dplyr::tbl(src = con, "AllCDData") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::filter(variable %in% selectedVariablesDisplay) |>
dplyr::collect()

}

# Retrieve feature importance
featureImportance <- dplyr::tbl(src = con, "AllFeatureImportance") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::filter(variable %in% selectedVariablesDisplay)
dplyr::filter(variable %in% selectedVariablesDisplay) |>
dplyr::collect()

# Close connection
DBI::dbDisconnect(conn = con)

assign(x = "featureImportance", value = featureImportance, envir = .GlobalEnv)
# assign(x = "aleData", value = aleData, envir = .GlobalEnv)

# Store data in reactive objects
measures_rval(measures)
aleData_rval(aleData)
meData_rval(meData)
featureImportance_rval(featureImportance)

# Stop busy spinner
Expand All @@ -63,6 +101,7 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {
bindEvent(focalSpecies(),
selectedModelDisplay(),
selectedVariablesDisplay(),
selectedMarginalEffectsPlot(),
ignoreInit = TRUE)


Expand Down Expand Up @@ -214,11 +253,11 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {
text = "Rendering ALE Plot"
)

aleData <- aleData_rval()
meData <- meData_rval()

output$ale_plot <- plotly::renderPlotly({

ale_plot <- create_ale_plot(ale_data = aleData)
ale_plot <- create_ale_plot(ale_data = meData)

ale_plotly <- plotly::ggplotly(ale_plot)

Expand All @@ -231,7 +270,7 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {


}) |>
bindEvent(aleData_rval(),
bindEvent(meData_rval(),
ignoreInit = FALSE)


Expand Down
7 changes: 0 additions & 7 deletions inst/app/modules/niche_models/nmModelRun_server.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,13 @@ nmModelRun <- function(input, output, session, sidebar_nm_options, nmDataInput)
identifyPredDrivers <- identifyPredDrivers()
})

# assign(x = "predictorData", value = predictorData, envir = .GlobalEnv)
# assign(x = "selectedModel", value = selectedModel, envir = .GlobalEnv)
# assign(x = "selectedExplainer", value = selectedExplainer, envir = .GlobalEnv)
# assign(x = "identifyPredDrivers", value = identifyPredDrivers, envir = .GlobalEnv)

results <- predict(selectedModel, newdata = predictorData, predict_type = "prob") |>
tibble::as_tibble() |>
dplyr::bind_cols(dplyr::select(predictors, id)) |>
dplyr::select("id" = "id",
"Prob.Presence" = "Present",
"Prob.Absence" = "Absent")

print(results)


# Identify model prediction drivers
if(isTRUE(identifyPredDrivers)){
Expand Down
2 changes: 2 additions & 0 deletions inst/app/modules/niche_models/nmSidebar_server.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ nmSidebar <- function(input, output, session) {
"focalSpecies" = input$focalSpecies,
"selectedModelDisplay" = input$selectedModelDisplay,
"selectedVariablesDisplay" = input$selectedVariablesDisplay,
"selectedMarginalEffectsPlot" = input$selectedMarginalEffectsPlot,
"identifyPredDrivers" = input$identifyPredDrivers,
"selectedModelPredict" = input$selectedModelPredict

Expand All @@ -24,6 +25,7 @@ nmSidebar <- function(input, output, session) {
input$focalSpecies,
input$selectedModelDisplay,
input$selectedVariablesDisplay,
input$selectedMarginalEffectsPlot,
input$identifyPredDrivers,
input$selectedModelPredict,
ignoreInit = TRUE)
Expand Down
44 changes: 42 additions & 2 deletions inst/app/modules/niche_models/nmSidebar_ui.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ nmSidebarUI <- function(id){
col_widths = c(11, 1),

shiny::selectizeInput(inputId = ns("selectedModelDisplay"),
label = "Selected Model",
label = "Model",
choices = c("GAM", "NNet", "GLM", "RF", "MARS", "SVM", "XGB", "WE"),
selected = c("GAM", "NNet", "GLM", "RF", "MARS", "SVM", "XGB", "WE"),
multiple = TRUE),
Expand Down Expand Up @@ -107,7 +107,7 @@ nmSidebarUI <- function(id){
col_widths = c(11, 1),

shiny::selectizeInput(inputId = ns("selectedVariablesDisplay"),
label = "Selected Variables",
label = "Variables",
choices = c("F", "L", "N", "R", "S", "DG", "DS", "H"),
selected = c("F", "L", "N", "R", "S", "DG", "DS", "H"),
multiple = TRUE),
Expand All @@ -128,6 +128,46 @@ nmSidebarUI <- function(id){

shiny::div(shiny::br())

),

shiny::div(

id = ns("selectedMarginalEffectsPlot_div"),

bslib::layout_columns(

col_widths = c(11, 1),

shiny::selectizeInput(inputId = ns("selectedMarginalEffectsPlot"),
label = "Marginal Effects Plot",
choices = c("ALE", "PDP", "CP"),
selected = "ALE",
multiple = FALSE),

bslib::popover(
bsicons::bs_icon("info-circle"),
title = "Marginal Effects Plot Selection",
id = ns("selectedMarginalEffectsPlot_info"),
shiny::markdown(
"
Select the marginal effects plot to view. Three options are available:
1) **Accumulated Local Effects Plot (ALE)**. Broadly, this can be interpreted as *the effect of a given variable on the total probability of the species occurrence.*
2) **Partial Dependence Plot (PDP)**. Broadly, this can be interpreted as *...*
3) **Ceteris Paribus Plot (CP)**. Broadly, this can be interpreted as *the probability of occurrence for that variable value, with all other variables held constant.*
Please see the documentation for more details.
"
),
placement = "bottom"
)

),

shiny::div(shiny::br())

)

),
Expand Down

0 comments on commit 512b253

Please sign in to comment.