From e5199e4c8026b04fa0e9782023df7eab9ca4b17e Mon Sep 17 00:00:00 2001 From: Fiona Seaton Date: Wed, 20 Dec 2023 15:12:46 +0000 Subject: [PATCH] Binomial tests Also move update tests to separate file to speed up parallel testing --- tests/testthat/test-posterior_predict.R | 28 ++++++++ tests/testthat/test-sim_data_funs.R | 26 ++++++- tests/testthat/test-stan_jsdm.R | 95 ++++++++++++++----------- tests/testthat/test-update.R | 64 +++++++++++++++++ 4 files changed, 171 insertions(+), 42 deletions(-) create mode 100644 tests/testthat/test-update.R diff --git a/tests/testthat/test-posterior_predict.R b/tests/testthat/test-posterior_predict.R index 4cc895e..6fd3697 100644 --- a/tests/testthat/test-posterior_predict.R +++ b/tests/testthat/test-posterior_predict.R @@ -122,3 +122,31 @@ test_that("posterior_(lin)pred works with mglmm and negbin", { expect_false(any(sapply(negb_pred2, anyNA))) expect_false(any(sapply(negb_pred2, function(x) x < 0))) }) + +bino_sim_data <- gllvm_sim_data(N = 100, S = 9, K = 2, family = "binomial", + site_intercept = "ungrouped", D = 2, + Ntrials = 20) +bino_pred_data <- matrix(rnorm(100 * 2), nrow = 100) +colnames(bino_pred_data) <- c("V1", "V2") +suppressWarnings(bino_fit <- stan_mglmm( + dat_list = bino_sim_data, family = "binomial", + refresh = 0, chains = 2, iter = 500 +)) +test_that("posterior_(lin)pred works with gllvm and bino", { + bino_pred <- posterior_predict(bino_fit, ndraws = 100) + + expect_length(bino_pred, 100) + expect_false(any(sapply(bino_pred, anyNA))) + expect_false(any(sapply(bino_pred, function(x) x < 0))) + expect_false(any(sapply(bino_pred, function(x) x > 20))) + + bino_pred2 <- posterior_predict(bino_fit, + newdata = bino_pred_data, Ntrials = 16, + ndraws = 50, list_index = "species" + ) + + expect_length(bino_pred2, 9) + expect_false(any(sapply(bino_pred2, anyNA))) + expect_false(any(sapply(bino_pred2, function(x) x < 0))) + expect_false(any(sapply(bino_pred2, function(x) x > 16))) +}) diff --git a/tests/testthat/test-sim_data_funs.R b/tests/testthat/test-sim_data_funs.R index 887ebd4..10793ac 100644 --- a/tests/testthat/test-sim_data_funs.R +++ b/tests/testthat/test-sim_data_funs.R @@ -18,6 +18,11 @@ test_that("gllvm_sim_data errors with bad inputs", { "Grouped site intercept not supported" ) + expect_error( + gllvm_sim_data(N = 200, S = 8, D = 2, family = "binomial", Ntrials = "1"), + "Ntrials must be a positive integer" + ) + }) test_that("gllvm_sim_data returns a list of correct length", { @@ -51,6 +56,16 @@ test_that("mglmm_sim_data errors with bad inputs", { site_intercept = "grouped"), "Grouped site intercept not supported" ) + + expect_error( + mglmm_sim_data(N = 100, S = 5, family = "binomial"), + "Number of trials must be specified" + ) + + expect_error( + mglmm_sim_data(N = 50, S = 8, family = "binomial", Ntrials = c(1,3)), + "Ntrials must be of length" + ) }) @@ -66,13 +81,19 @@ test_that("mglmm_sim_data returns a list of correct length", { expect_named(mglmm_sim, c( "Y", "pars", "N", "S", "D", "K", "X" )) + gllvm_sim <- jsdm_sim_data(100,12,D=2,family = "binomial", method = "gllvm", + Ntrials = 19) + expect_named(gllvm_sim, c( + "Y", "pars", "N", "S", "D", "K", "X", "Ntrials" + )) + expect_length(gllvm_sim$Ntrials, 100) }) test_that("jsdm_sim_data returns all appropriate pars", { mglmm_sim <- jsdm_sim_data(100,12,family = "gaussian", method = "mglmm", - beta_param = "cor") + beta_param = "cor", K = 3) expect_named(mglmm_sim$pars, c( - "betas","sigmas_preds","z_preds","sigmas_species", + "betas","sigmas_preds","z_preds","cor_preds","sigmas_species", "cor_species","z_species","sigma" )) gllvm_sim <- jsdm_sim_data(100,12,D=2,family = "neg_bin", method = "gllvm", @@ -83,6 +104,7 @@ test_that("jsdm_sim_data returns all appropriate pars", { )) + }) test_that("prior specification works", { diff --git a/tests/testthat/test-stan_jsdm.R b/tests/testthat/test-stan_jsdm.R index 0741dee..6dfab1f 100644 --- a/tests/testthat/test-stan_jsdm.R +++ b/tests/testthat/test-stan_jsdm.R @@ -43,27 +43,6 @@ test_that("summary works", { expect_match(rownames(mglmm_summ), "beta", all = TRUE) }) -test_that("update works", { - mglmm_data <- mglmm_sim_data(N = 20, S = 5, family = "gaussian", K = 2) - suppressWarnings(mglmm_fit2 <- update(mglmm_fit, - newY = mglmm_data$Y, - newX = mglmm_data$X, - refresh = 0, - chains = 2, iter = 200 - )) - suppressWarnings(mglmm_fit3 <- update(mglmm_fit, - refresh = 0, iter = 100 - )) - - expect_s3_class(mglmm_fit2, "jsdmStanFit") - expect_s3_class(mglmm_fit3, "jsdmStanFit") - - jsdm_empty <- jsdmStanFit_empty() - expect_error( - update(jsdm_empty), - "Update requires the original data to be saved in the model object" - ) -}) test_that("nuts_params works", { expect_named(nuts_params(mglmm_fit), c("Chain", "Iteration", "Parameter", "Value")) @@ -126,6 +105,23 @@ test_that("stan_gllvm fails with wrong inputs", { family = "bern", D = -1), "Must have at least one latent variable" ) + + expect_error( + suppressMessages(stan_gllvm(Y = matrix(sample.int(10,100, + replace = TRUE), + nrow = 20), + X = matrix(rnorm(60), nrow = 20), + family = "binomial", D = 2, Ntrials = "a")), + "Ntrials must be a positive integer" + ) + + expect_error( + suppressMessages(stan_gllvm(Y = matrix(sample.int(10,100, + replace = TRUE), nrow = 20), + X = matrix(rnorm(60), nrow = 20), + family = "binomial", D = 2, Ntrials = c(1,3))), + "Ntrials must be of length" + ) }) test_that("stan_gllvm returns right type of object", { @@ -165,6 +161,17 @@ test_that("stan_gllvm returns right type of object", { )) expect_s3_class(gllvm_fit, "jsdmStanFit") + + # binomial + gllvm_data <- gllvm_sim_data(N = 20, S = 8, D = 2, K = 2, family = "binomial", + Ntrials = 20) + suppressWarnings(gllvm_fit <- stan_gllvm( + Y = as.data.frame(gllvm_data$Y), X = as.data.frame(gllvm_data$X), + D = gllvm_data$D, refresh = 0, chains = 2, iter = 200, + family = "binomial", Ntrials = 20 + )) + + expect_s3_class(gllvm_fit, "jsdmStanFit") }) # MGLMM tests ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -187,6 +194,15 @@ test_that("stan_mglmm fails with wrong inputs", { ), "Y matrix is not composed of integers" ) + expect_error( + stan_mglmm( + Y = matrix(sample.int(100,500,replace=TRUE), nrow = 100), + X = data.frame(V1 = rnorm(100), + V2 = rnorm(100)), + family = "binomial" + ), + "Number of trials must be specified" + ) }) test_that("stan_mglmm returns right type of object", { @@ -216,6 +232,10 @@ test_that("stan_mglmm returns right type of object", { chains = 2 )) + expect_error(stan_jsdm(dat_list = mglmm_data, family = "binomial", + method = "mglmm"), + "Binomial models require Ntrials") + expect_s3_class(mglmm_fit, "jsdmStanFit") # neg bin @@ -227,6 +247,20 @@ test_that("stan_mglmm returns right type of object", { )) expect_s3_class(mglmm_fit, "jsdmStanFit") + + # binomial + mglmm_data <- mglmm_sim_data(N = 51, S = 6, K = 2, family = "binomial", + Ntrials=sample.int(20,51,replace = TRUE)) + suppressWarnings(mglmm_fit <- stan_mglmm( + dat_list = mglmm_data, + family = "binomial", + refresh = 0, chains = 2, iter = 200 + )) + + expect_s3_class(mglmm_fit, "jsdmStanFit") + + # binomial + }) # stan_jsdm site_intercept tests ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -235,11 +269,6 @@ mglmm_data <- mglmm_sim_data(N = 100, S = 8, family = "gaussian", K = 3, df <- as.data.frame(mglmm_data$X) grps <- rep(1:20, each = 5) -gllvm_data <- gllvm_sim_data(N = 100, S = 8, family = "bern", D = 3, - site_intercept = "ungrouped") -gllvm_data$grps <- rep(1:20, each = 5) -gllvm_data$ngrp <- 20 - test_that("site intercept errors correctly", { expect_error(stan_mglmm(~ V1 + V2, data = df, Y = mglmm_data$Y, site_intercept = "fauh",family = "gaussian")) @@ -270,17 +299,3 @@ test_that("site intercept models run", { expect_s3_class(mglmm_fit, "jsdmStanFit") }) - -test_that("site_intercept models update", { - suppressWarnings(gllvm_fit <- stan_gllvm(X = NULL, dat_list = gllvm_data, - site_intercept = "grouped", - family = "bern", - refresh = 0, chains = 1, iter = 200 - )) - expect_s3_class(gllvm_fit, "jsdmStanFit") - - suppressWarnings(gllvm_fit2 <- update(gllvm_fit, newD = 2, - refresh = 0, chains = 1, iter = 200 - )) - expect_s3_class(gllvm_fit2, "jsdmStanFit") -}) diff --git a/tests/testthat/test-update.R b/tests/testthat/test-update.R new file mode 100644 index 0000000..1bee5e1 --- /dev/null +++ b/tests/testthat/test-update.R @@ -0,0 +1,64 @@ +mglmm_data <- mglmm_sim_data(N = 30, S = 8, family = "gaussian", K = 3) +df <- as.data.frame(mglmm_data$X) + +suppressWarnings(mglmm_fit <- stan_jsdm(~ V1 + V2 + V3, + data = df, Y = mglmm_data$Y, + family = "gaussian", method = "mglmm", + refresh = 0, chains = 2, iter = 200 +)) +test_that("update works", { + mglmm_data <- mglmm_sim_data(N = 20, S = 5, family = "gaussian", K = 2) + suppressWarnings(mglmm_fit2 <- update(mglmm_fit, + newY = mglmm_data$Y, + newX = mglmm_data$X, + refresh = 0, + chains = 2, iter = 200 + )) + suppressWarnings(mglmm_fit3 <- update(mglmm_fit, + refresh = 0, iter = 100 + )) + + expect_s3_class(mglmm_fit2, "jsdmStanFit") + expect_s3_class(mglmm_fit3, "jsdmStanFit") + + jsdm_empty <- jsdmStanFit_empty() + expect_error( + update(jsdm_empty), + "Update requires the original data to be saved in the model object" + ) +}) + +gllvm_data <- gllvm_sim_data(N = 100, S = 8, family = "bern", D = 3, + site_intercept = "ungrouped") +gllvm_data$grps <- rep(1:20, each = 5) +gllvm_data$ngrp <- 20 + +test_that("site_intercept models update", { + suppressWarnings(gllvm_fit <- stan_gllvm(X = NULL, dat_list = gllvm_data, + site_intercept = "grouped", + family = "bern", + refresh = 0, chains = 1, iter = 200 + )) + expect_s3_class(gllvm_fit, "jsdmStanFit") + + suppressWarnings(gllvm_fit2 <- update(gllvm_fit, newD = 2, + refresh = 0, chains = 1, iter = 200 + )) + expect_s3_class(gllvm_fit2, "jsdmStanFit") +}) + +gllvm_data <- gllvm_sim_data(N = 100, S = 8, family = "binomial", D = 3, + site_intercept = "ungrouped", Ntrials = 20) + +test_that("binomial models update", { + suppressWarnings(gllvm_fit <- stan_gllvm(dat_list = gllvm_data, + family = "binomial", + refresh = 0, chains = 1, iter = 200 + )) + expect_s3_class(gllvm_fit, "jsdmStanFit") + + suppressWarnings(gllvm_fit2 <- update(gllvm_fit, newD = 2, newNtrials = 30, + refresh = 0, chains = 1, iter = 200 + )) + expect_s3_class(gllvm_fit2, "jsdmStanFit") +})