Skip to content

Commit

Permalink
fix chol problem and beta param default
Browse files Browse the repository at this point in the history
  • Loading branch information
fseaton committed Dec 14, 2023
1 parent f94868d commit 44f40b1
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 52 deletions.
4 changes: 2 additions & 2 deletions R/jsdm_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(),
vector<lower=0>[K] sigmas_preds;
matrix[K, S] z_preds;
// covariance matrix on betas by predictors
cholesky_factor_corr[K] cor_preds;", "unstruct" = "
corr_matrix[K] cor_preds;", "unstruct" = "
matrix[K, S] betas;")
mglmm_spcov_pars <- "
// species covariances
vector<lower=0>[S] sigmas_species;
matrix[S, N] z_species;
cholesky_factor_corr[S] cor_species;"
corr_matrix[S] cor_species;"
gllvm_pars <- "
// Factor parameters
vector[M] L; // Non-zero factor loadings
Expand Down
10 changes: 5 additions & 5 deletions R/prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
#' to be positive (default standard normal)
#' @param z_preds The covariate effects (default standard normal)
#' @param cor_preds The correlation matrix on the covariate effects (npred by npred
#' matrix represented as a Cholesky factor of a correlation matrix) (default
#' \code{"lkj_corr_cholesky(1)"})
#' correlation matrix) (default
#' \code{"lkj_corr(1)"})
#' @param betas If covariate effects are unstructured, the prior on the covariate
#' effects
#' @param a The site level intercepts (default standard normal)
Expand All @@ -43,7 +43,7 @@
#' @param z_species For MGLMM method, the S by N matrix of species covariance by site
#' (default standard normal)
#' @param cor_species For MGLMM method, the correlation between species represented
#' as a Cholesky factor correlation matrix (default \code{"lkj_corr_cholesky(1)"})
#' as a nspecies by nspecies correlation matrix (default \code{"lkj_corr(1)"})
#' @param LV For GLLVM method, the per site latent variable loadings (default
#' standard normal)
#' @param L For GLLVM method, the non-zero species latent variable loadings (default
Expand All @@ -64,14 +64,14 @@
#'
jsdm_prior <- function(sigmas_preds = "normal(0,1)",
z_preds = "normal(0,1)",
cor_preds = "lkj_corr_cholesky(1)",
cor_preds = "lkj_corr(1)",
betas = "normal(0,1)",
a = "normal(0,1)",
a_bar = "normal(0,1)",
sigma_a = "normal(0,1)",
sigmas_species = "normal(0,1)",
z_species = "normal(0,1)",
cor_species = "lkj_corr_cholesky(1)",
cor_species = "lkj_corr(1)",
LV = "normal(0,1)",
L = "normal(0,1)",
sigma_L = "normal(0,1)",
Expand Down
48 changes: 31 additions & 17 deletions R/sim_data_funs.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@
#' currently.
#'
#' @param beta_param The parameterisation of the environmental covariate effects, by
#' default \code{"cor"}. See details for further information.
#' default \code{"unstruct"}. See details for further information.
#'
#' @param prior Set of prior specifications from call to [jsdm_prior()]
jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "mglmm"),
species_intercept = TRUE,
site_intercept = "none",
beta_param = "cor",
beta_param = "unstruct",
prior = jsdm_prior()) {
response <- match.arg(family, c("gaussian", "neg_binomial", "poisson", "bernoulli"))
site_intercept <- match.arg(site_intercept, c("none","ungrouped","grouped"))
Expand Down Expand Up @@ -96,6 +96,7 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
"normal",
"inv_gamma",
"lkj_corr_cholesky",
"lkj_corr",
"student_t",
"cauchy",
"gamma"
Expand All @@ -118,6 +119,7 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
"normal" = "rnorm",
"inv_gamma" = "rinvgamma",
"lkj_corr_cholesky" = "rlkj",
"lkj_corr" = "rlkj",
"student_t" = "rstudentt",
"cauchy" = "rcauchy",
"gamma" = "rgamma"
Expand All @@ -141,6 +143,12 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
)
fun_args <- as.list(c(fun_arg1, as.numeric(unlist(y[[1]][[1]])[-1])))

if(x == "cor_preds" & grepl("lkj_corr_cholesky\\(",prior$cor_species))
fun_args <- c(fun_args,1)
if(x == "cor_species" & grepl("lkj_corr_cholesky\\(",prior$cor_species))
fun_args <- c(fun_args,1)


return(list(fun_name, fun_args))
})
names(prior_func) <- names(prior_split)
Expand Down Expand Up @@ -175,22 +183,22 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m

# covariate parameters
if(beta_param == "cor"){
beta_sds <- abs(do.call(
sigmas_preds <- abs(do.call(
match.fun(prior_func[["sigmas_preds"]][[1]]),
prior_func[["sigmas_preds"]][[2]]
))
z_betas <- matrix(do.call(
z_preds <- matrix(do.call(
match.fun(prior_func[["z_preds"]][[1]]),
prior_func[["z_preds"]][[2]]
), ncol = S, nrow = J)
if (K == 0) {
beta_sim <- beta_sds %*% z_betas
beta_sim <- sigmas_preds %*% z_preds
} else {
cor_preds <- do.call(
match.fun(prior_func[["cor_preds"]][[1]]),
prior_func[["cor_preds"]][[2]]
)
beta_sim <- (diag(beta_sds) %*% cor_preds) %*% z_betas
beta_sim <- (diag(sigmas_preds) %*% cor_preds) %*% z_preds
}
} else if (beta_param == "unstruct"){
beta_sim <- matrix(do.call(
Expand Down Expand Up @@ -223,15 +231,15 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m

if (method == "mglmm") {
# u covariance
u_sds <- abs(do.call(
sigmas_species <- abs(do.call(
match.fun(prior_func[["sigmas_species"]][[1]]),
prior_func[["sigmas_species"]][[2]]
))
u_ftilde <- matrix(do.call(
z_species <- matrix(do.call(
match.fun(prior_func[["z_species"]][[1]]),
prior_func[["z_species"]][[2]]
), nrow = S, ncol = N)
u_ij <- t((diag(u_sds) %*% cor_species) %*% u_ftilde)
u_ij <- t((diag(sigmas_species) %*% cor_species) %*% z_species)
}

if (method == "gllvm") {
Expand All @@ -254,7 +262,7 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
}
}

L_sigma <- abs(do.call(
sigma_L <- abs(do.call(
match.fun(prior_func[["sigma_L"]][[1]]),
prior_func[["sigma_L"]][[2]]
))
Expand All @@ -264,7 +272,7 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
prior_func[["LV"]][[2]]
), nrow = D, ncol = N)

LV_sum <- (L * L_sigma) %*% LV
LV_sum <- (L * sigma_L) %*% LV
}

# variance parameters
Expand Down Expand Up @@ -309,8 +317,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
)

if(beta_param == "cor"){
pars$beta_sds <- beta_sds
pars$z_betas <- z_betas
pars$sigmas_preds <- sigmas_preds
pars$z_preds <- z_preds
}

if (site_intercept == "ungrouped") {
Expand All @@ -321,12 +329,18 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
if (method == "gllvm") {
pars$L <- L
pars$LV <- LV
pars$L_sigma <- L_sigma
pars$sigma_L <- sigma_L
}
if (method == "mglmm") {
pars$u_sds <- u_sds
pars$sigmas_species <- sigmas_species
pars$cor_species <- cor_species
pars$u_ftilde <- u_ftilde
pars$z_species <- z_species
}
if (response == "gaussian") {
pars$sigma <- sigma
}
if (response == "neg_binomial") {
pars$kappa <- kappa
}
if (isTRUE(species_intercept)) {
if (K > 0) {
Expand Down Expand Up @@ -440,7 +454,7 @@ rgbeta <-
#' @rdname sim_helpers
#' @export
rlkj <-
function(n, eta = 1, cholesky = TRUE) {
function(n, eta = 1, cholesky = FALSE) {
if (n < 2) {
stop("Dimension of correlation matrix must be >= 2")
}
Expand Down
4 changes: 2 additions & 2 deletions R/stan_jsdm.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
#' quantities (by default TRUE), required for loo
#'
#' @param beta_param The parameterisation of the environmental covariate effects, by
#' default \code{"cor"}. See details for further information.
#' default \code{"unstruct"}. See details for further information.
#'
#' @param ... Arguments passed to [rstan::sampling()]
#'
Expand Down Expand Up @@ -97,7 +97,7 @@ stan_jsdm <- function(X, ...) UseMethod("stan_jsdm")
stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, method,
dat_list = NULL, family, site_intercept = "none",
D = NULL, prior = jsdm_prior(), site_groups = NULL,
beta_param = "cor",
beta_param = "unstruct",
save_data = TRUE, iter = 4000, log_lik = TRUE, ...) {
family <- match.arg(family, c("gaussian", "bernoulli", "poisson", "neg_binomial"))
beta_param <- match.arg(beta_param, c("cor", "unstruct"))
Expand Down
10 changes: 5 additions & 5 deletions man/jsdm_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/jsdm_sim_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/sim_helpers.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/stan_jsdm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 31 additions & 6 deletions tests/testthat/test-posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ bern_pred_data <- matrix(rnorm(100 * 2), nrow = 100)
colnames(bern_pred_data) <- c("V1", "V2")
suppressWarnings(bern_fit <- stan_gllvm(
dat_list = bern_sim_data, family = "bern",
refresh = 0, iter = 1000, chains = 2
refresh = 0, iter = 500, chains = 2
))

test_that("posterior linpred errors appropriately", {
Expand Down Expand Up @@ -65,8 +65,8 @@ test_that("posterior_(lin)pred works with gllvm", {
expect_false(any(sapply(bern_pred, function(x) x < 0)))

bern_pred2 <- posterior_predict(bern_fit,
newdata = bern_pred_data,
ndraws = 50, list_index = "species"
newdata = bern_pred_data,
ndraws = 50, list_index = "species"
)

expect_length(bern_pred2, 9)
Expand All @@ -79,7 +79,7 @@ pois_pred_data <- matrix(rnorm(100 * 2), nrow = 100)
colnames(pois_pred_data) <- c("V1", "V2")
suppressWarnings(pois_fit <- stan_mglmm(
dat_list = pois_sim_data, family = "pois",
refresh = 0, chains = 2
refresh = 0, chains = 2, iter = 500
))
test_that("posterior_(lin)pred works with mglmm", {
pois_pred <- posterior_predict(pois_fit, ndraws = 100)
Expand All @@ -89,11 +89,36 @@ test_that("posterior_(lin)pred works with mglmm", {
expect_false(any(sapply(pois_pred, function(x) x < 0)))

pois_pred2 <- posterior_predict(pois_fit,
newdata = pois_pred_data,
ndraws = 50, list_index = "species"
newdata = pois_pred_data,
ndraws = 50, list_index = "species"
)

expect_length(pois_pred2, 9)
expect_false(any(sapply(pois_pred2, anyNA)))
expect_false(any(sapply(pois_pred2, function(x) x < 0)))
})

negb_sim_data <- mglmm_sim_data(N = 100, S = 9, K = 2, family = "neg_bin",
site_intercept = "ungrouped")
negb_pred_data <- matrix(rnorm(100 * 2), nrow = 100)
colnames(negb_pred_data) <- c("V1", "V2")
suppressWarnings(negb_fit <- stan_mglmm(
dat_list = negb_sim_data, family = "neg_bin",
refresh = 0, chains = 2, iter = 500
))
test_that("posterior_(lin)pred works with mglmm and negbin", {
negb_pred <- posterior_predict(negb_fit, ndraws = 100)

expect_length(negb_pred, 100)
expect_false(any(sapply(negb_pred, anyNA)))
expect_false(any(sapply(negb_pred, function(x) x < 0)))

negb_pred2 <- posterior_predict(negb_fit,
newdata = negb_pred_data,
ndraws = 50, list_index = "species"
)

expect_length(negb_pred2, 9)
expect_false(any(sapply(negb_pred2, anyNA)))
expect_false(any(sapply(negb_pred2, function(x) x < 0)))
})
17 changes: 17 additions & 0 deletions tests/testthat/test-sim_data_funs.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,23 @@ test_that("mglmm_sim_data returns a list of correct length", {
))
})

test_that("jsdm_sim_data returns all appropriate pars", {
mglmm_sim <- jsdm_sim_data(100,12,family = "gaussian", method = "mglmm",
beta_param = "cor")
expect_named(mglmm_sim$pars, c(
"betas","sigmas_preds","z_preds","sigmas_species",
"cor_species","z_species","sigma"
))
gllvm_sim <- jsdm_sim_data(100,12,D=2,family = "neg_bin", method = "gllvm",
beta_param = "unstruct",
site_intercept = "ungrouped")
expect_named(gllvm_sim$pars, c(
"betas","a_bar","sigma_a","a","L","LV","sigma_L","kappa"
))


})

test_that("prior specification works", {
jsdm_sim <- jsdm_sim_data(
N = 100, S = 8, K = 2, family = "gaus", method = "mglmm",
Expand Down
Loading

0 comments on commit 44f40b1

Please sign in to comment.