Skip to content

Commit

Permalink
Binomial tests
Browse files Browse the repository at this point in the history
Also move update tests to separate file to speed up parallel testing
  • Loading branch information
fseaton committed Dec 20, 2023
1 parent 9b7f25f commit e5199e4
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 42 deletions.
28 changes: 28 additions & 0 deletions tests/testthat/test-posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,31 @@ test_that("posterior_(lin)pred works with mglmm and negbin", {
expect_false(any(sapply(negb_pred2, anyNA)))
expect_false(any(sapply(negb_pred2, function(x) x < 0)))
})

bino_sim_data <- gllvm_sim_data(N = 100, S = 9, K = 2, family = "binomial",
site_intercept = "ungrouped", D = 2,
Ntrials = 20)
bino_pred_data <- matrix(rnorm(100 * 2), nrow = 100)
colnames(bino_pred_data) <- c("V1", "V2")
suppressWarnings(bino_fit <- stan_mglmm(
dat_list = bino_sim_data, family = "binomial",
refresh = 0, chains = 2, iter = 500
))
test_that("posterior_(lin)pred works with gllvm and bino", {
bino_pred <- posterior_predict(bino_fit, ndraws = 100)

expect_length(bino_pred, 100)
expect_false(any(sapply(bino_pred, anyNA)))
expect_false(any(sapply(bino_pred, function(x) x < 0)))
expect_false(any(sapply(bino_pred, function(x) x > 20)))

bino_pred2 <- posterior_predict(bino_fit,
newdata = bino_pred_data, Ntrials = 16,
ndraws = 50, list_index = "species"
)

expect_length(bino_pred2, 9)
expect_false(any(sapply(bino_pred2, anyNA)))
expect_false(any(sapply(bino_pred2, function(x) x < 0)))
expect_false(any(sapply(bino_pred2, function(x) x > 16)))
})
26 changes: 24 additions & 2 deletions tests/testthat/test-sim_data_funs.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ test_that("gllvm_sim_data errors with bad inputs", {
"Grouped site intercept not supported"
)

expect_error(
gllvm_sim_data(N = 200, S = 8, D = 2, family = "binomial", Ntrials = "1"),
"Ntrials must be a positive integer"
)

})

test_that("gllvm_sim_data returns a list of correct length", {
Expand Down Expand Up @@ -51,6 +56,16 @@ test_that("mglmm_sim_data errors with bad inputs", {
site_intercept = "grouped"),
"Grouped site intercept not supported"
)

expect_error(
mglmm_sim_data(N = 100, S = 5, family = "binomial"),
"Number of trials must be specified"
)

expect_error(
mglmm_sim_data(N = 50, S = 8, family = "binomial", Ntrials = c(1,3)),
"Ntrials must be of length"
)
})


Expand All @@ -66,13 +81,19 @@ test_that("mglmm_sim_data returns a list of correct length", {
expect_named(mglmm_sim, c(
"Y", "pars", "N", "S", "D", "K", "X"
))
gllvm_sim <- jsdm_sim_data(100,12,D=2,family = "binomial", method = "gllvm",
Ntrials = 19)
expect_named(gllvm_sim, c(
"Y", "pars", "N", "S", "D", "K", "X", "Ntrials"
))
expect_length(gllvm_sim$Ntrials, 100)
})

test_that("jsdm_sim_data returns all appropriate pars", {
mglmm_sim <- jsdm_sim_data(100,12,family = "gaussian", method = "mglmm",
beta_param = "cor")
beta_param = "cor", K = 3)
expect_named(mglmm_sim$pars, c(
"betas","sigmas_preds","z_preds","sigmas_species",
"betas","sigmas_preds","z_preds","cor_preds","sigmas_species",
"cor_species","z_species","sigma"
))
gllvm_sim <- jsdm_sim_data(100,12,D=2,family = "neg_bin", method = "gllvm",
Expand All @@ -83,6 +104,7 @@ test_that("jsdm_sim_data returns all appropriate pars", {
))



})

test_that("prior specification works", {
Expand Down
95 changes: 55 additions & 40 deletions tests/testthat/test-stan_jsdm.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,6 @@ test_that("summary works", {
expect_match(rownames(mglmm_summ), "beta", all = TRUE)
})

test_that("update works", {
mglmm_data <- mglmm_sim_data(N = 20, S = 5, family = "gaussian", K = 2)
suppressWarnings(mglmm_fit2 <- update(mglmm_fit,
newY = mglmm_data$Y,
newX = mglmm_data$X,
refresh = 0,
chains = 2, iter = 200
))
suppressWarnings(mglmm_fit3 <- update(mglmm_fit,
refresh = 0, iter = 100
))

expect_s3_class(mglmm_fit2, "jsdmStanFit")
expect_s3_class(mglmm_fit3, "jsdmStanFit")

jsdm_empty <- jsdmStanFit_empty()
expect_error(
update(jsdm_empty),
"Update requires the original data to be saved in the model object"
)
})

test_that("nuts_params works", {
expect_named(nuts_params(mglmm_fit), c("Chain", "Iteration", "Parameter", "Value"))
Expand Down Expand Up @@ -126,6 +105,23 @@ test_that("stan_gllvm fails with wrong inputs", {
family = "bern", D = -1),
"Must have at least one latent variable"
)

expect_error(
suppressMessages(stan_gllvm(Y = matrix(sample.int(10,100,
replace = TRUE),
nrow = 20),
X = matrix(rnorm(60), nrow = 20),
family = "binomial", D = 2, Ntrials = "a")),
"Ntrials must be a positive integer"
)

expect_error(
suppressMessages(stan_gllvm(Y = matrix(sample.int(10,100,
replace = TRUE), nrow = 20),
X = matrix(rnorm(60), nrow = 20),
family = "binomial", D = 2, Ntrials = c(1,3))),
"Ntrials must be of length"
)
})

test_that("stan_gllvm returns right type of object", {
Expand Down Expand Up @@ -165,6 +161,17 @@ test_that("stan_gllvm returns right type of object", {
))

expect_s3_class(gllvm_fit, "jsdmStanFit")

# binomial
gllvm_data <- gllvm_sim_data(N = 20, S = 8, D = 2, K = 2, family = "binomial",
Ntrials = 20)
suppressWarnings(gllvm_fit <- stan_gllvm(
Y = as.data.frame(gllvm_data$Y), X = as.data.frame(gllvm_data$X),
D = gllvm_data$D, refresh = 0, chains = 2, iter = 200,
family = "binomial", Ntrials = 20
))

expect_s3_class(gllvm_fit, "jsdmStanFit")
})

# MGLMM tests ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -187,6 +194,15 @@ test_that("stan_mglmm fails with wrong inputs", {
),
"Y matrix is not composed of integers"
)
expect_error(
stan_mglmm(
Y = matrix(sample.int(100,500,replace=TRUE), nrow = 100),
X = data.frame(V1 = rnorm(100),
V2 = rnorm(100)),
family = "binomial"
),
"Number of trials must be specified"
)
})

test_that("stan_mglmm returns right type of object", {
Expand Down Expand Up @@ -216,6 +232,10 @@ test_that("stan_mglmm returns right type of object", {
chains = 2
))

expect_error(stan_jsdm(dat_list = mglmm_data, family = "binomial",
method = "mglmm"),
"Binomial models require Ntrials")

expect_s3_class(mglmm_fit, "jsdmStanFit")

# neg bin
Expand All @@ -227,6 +247,20 @@ test_that("stan_mglmm returns right type of object", {
))

expect_s3_class(mglmm_fit, "jsdmStanFit")

# binomial
mglmm_data <- mglmm_sim_data(N = 51, S = 6, K = 2, family = "binomial",
Ntrials=sample.int(20,51,replace = TRUE))
suppressWarnings(mglmm_fit <- stan_mglmm(
dat_list = mglmm_data,
family = "binomial",
refresh = 0, chains = 2, iter = 200
))

expect_s3_class(mglmm_fit, "jsdmStanFit")

# binomial

})

# stan_jsdm site_intercept tests ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -235,11 +269,6 @@ mglmm_data <- mglmm_sim_data(N = 100, S = 8, family = "gaussian", K = 3,
df <- as.data.frame(mglmm_data$X)
grps <- rep(1:20, each = 5)

gllvm_data <- gllvm_sim_data(N = 100, S = 8, family = "bern", D = 3,
site_intercept = "ungrouped")
gllvm_data$grps <- rep(1:20, each = 5)
gllvm_data$ngrp <- 20

test_that("site intercept errors correctly", {
expect_error(stan_mglmm(~ V1 + V2, data = df, Y = mglmm_data$Y,
site_intercept = "fauh",family = "gaussian"))
Expand Down Expand Up @@ -270,17 +299,3 @@ test_that("site intercept models run", {
expect_s3_class(mglmm_fit, "jsdmStanFit")

})

test_that("site_intercept models update", {
suppressWarnings(gllvm_fit <- stan_gllvm(X = NULL, dat_list = gllvm_data,
site_intercept = "grouped",
family = "bern",
refresh = 0, chains = 1, iter = 200
))
expect_s3_class(gllvm_fit, "jsdmStanFit")

suppressWarnings(gllvm_fit2 <- update(gllvm_fit, newD = 2,
refresh = 0, chains = 1, iter = 200
))
expect_s3_class(gllvm_fit2, "jsdmStanFit")
})
64 changes: 64 additions & 0 deletions tests/testthat/test-update.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
mglmm_data <- mglmm_sim_data(N = 30, S = 8, family = "gaussian", K = 3)
df <- as.data.frame(mglmm_data$X)

suppressWarnings(mglmm_fit <- stan_jsdm(~ V1 + V2 + V3,
data = df, Y = mglmm_data$Y,
family = "gaussian", method = "mglmm",
refresh = 0, chains = 2, iter = 200
))
test_that("update works", {
mglmm_data <- mglmm_sim_data(N = 20, S = 5, family = "gaussian", K = 2)
suppressWarnings(mglmm_fit2 <- update(mglmm_fit,
newY = mglmm_data$Y,
newX = mglmm_data$X,
refresh = 0,
chains = 2, iter = 200
))
suppressWarnings(mglmm_fit3 <- update(mglmm_fit,
refresh = 0, iter = 100
))

expect_s3_class(mglmm_fit2, "jsdmStanFit")
expect_s3_class(mglmm_fit3, "jsdmStanFit")

jsdm_empty <- jsdmStanFit_empty()
expect_error(
update(jsdm_empty),
"Update requires the original data to be saved in the model object"
)
})

gllvm_data <- gllvm_sim_data(N = 100, S = 8, family = "bern", D = 3,
site_intercept = "ungrouped")
gllvm_data$grps <- rep(1:20, each = 5)
gllvm_data$ngrp <- 20

test_that("site_intercept models update", {
suppressWarnings(gllvm_fit <- stan_gllvm(X = NULL, dat_list = gllvm_data,
site_intercept = "grouped",
family = "bern",
refresh = 0, chains = 1, iter = 200
))
expect_s3_class(gllvm_fit, "jsdmStanFit")

suppressWarnings(gllvm_fit2 <- update(gllvm_fit, newD = 2,
refresh = 0, chains = 1, iter = 200
))
expect_s3_class(gllvm_fit2, "jsdmStanFit")
})

gllvm_data <- gllvm_sim_data(N = 100, S = 8, family = "binomial", D = 3,
site_intercept = "ungrouped", Ntrials = 20)

test_that("binomial models update", {
suppressWarnings(gllvm_fit <- stan_gllvm(dat_list = gllvm_data,
family = "binomial",
refresh = 0, chains = 1, iter = 200
))
expect_s3_class(gllvm_fit, "jsdmStanFit")

suppressWarnings(gllvm_fit2 <- update(gllvm_fit, newD = 2, newNtrials = 30,
refresh = 0, chains = 1, iter = 200
))
expect_s3_class(gllvm_fit2, "jsdmStanFit")
})

0 comments on commit e5199e4

Please sign in to comment.