Skip to content

Commit

Permalink
Merge pull request #4 from NERC-CEH/binomial
Browse files Browse the repository at this point in the history
Adding Binomial models
  • Loading branch information
fseaton authored Dec 20, 2023
2 parents d35413a + fe46255 commit 645d05e
Show file tree
Hide file tree
Showing 16 changed files with 393 additions and 163 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: jsdmstan
Title: Fitting jSDMs in Stan
Version: 0.2.0
Version: 0.3.0
Authors@R:
person("Fiona", "Seaton", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0002-2022-7451"))
Expand Down
45 changes: 24 additions & 21 deletions R/jsdm_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(),
log_lik = TRUE, site_intercept = "none",
beta_param = "cor") {
# checks
family <- match.arg(family, c("gaussian", "bernoulli", "poisson", "neg_binomial"))
family <- match.arg(family, c("gaussian", "bernoulli", "poisson",
"neg_binomial","binomial"))
method <- match.arg(method, c("gllvm", "mglmm"))
beta_param <- match.arg(beta_param, c("cor","unstruct"))
site_intercept <- match.arg(site_intercept, c("none","grouped","ungrouped"))
Expand All @@ -62,23 +63,29 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(),
data <- paste(
" int<lower=1> N; // Number of sites
int<lower=1> S; // Number of species
", ifelse(method == "gllvm",
"int<lower=1> D; // Number of latent dimensions", ""
",
ifelse(method == "gllvm",
" int<lower=1> D; // Number of latent dimensions", ""
),
"
int<lower=0> K; // Number of predictor variables
matrix[N, K] X; // Predictor matrix
", ifelse(site_intercept == "grouped",
",
ifelse(site_intercept == "grouped",
"
int<lower=1> ngrp; // Number of groups in site intercept
int<lower=0, upper = ngrp> grps[N]; // Vector matching sites to groups
",""),
",""),
switch(family,
"gaussian" = "real",
"bernoulli" = "int<lower=0,upper=1>",
"neg_binomial" = "int<lower=0>",
"poisson" = "int<lower=0>"
), "Y[N,S]; //Species matrix"
"poisson" = "int<lower=0>",
"binomial" = "int<lower=0>"
), "Y[N,S]; //Species matrix",
ifelse(family == "binomial",
"
int<lower=0> Ntrials[N]; // Number of trials","")
)
transformed_data <- ifelse(method == "gllvm", "
// Ensures identifiability of the model - no rotation of factors
Expand Down Expand Up @@ -255,7 +262,8 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(),
kappa ~ ", prior[["kappa"]], ";
"),
"bern" = "",
"poisson" = ""
"poisson" = "",
"binomial" = ""
)
)
model_pt2 <- paste(
Expand All @@ -265,7 +273,8 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(),
"gaussian" = "normal(mu[i,], sigma);",
"bernoulli" = "bernoulli_logit(mu[i,]);",
"neg_binomial" = "neg_binomial_2_log(mu[i,], kappa);",
"poisson" = "poisson_log(mu[i,]);"
"poisson" = "poisson_log(mu[i,]);",
"binomial" = "binomial_logit(Ntrials[i], mu[i,]);"
)
)

Expand Down Expand Up @@ -316,18 +325,12 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(),
for(j in 1:S) {
log_lik[i, j] = ",
switch(family,
"gaussian" = "normal_lpdf",
"bernoulli" = "bernoulli_logit_lpmf",
"neg_binomial" = "neg_binomial_2_log_lpmf",
"poisson" = "poisson_log_lpmf"
),
"(Y[i, j] | linpred[i, j]",
switch(family,
"gaussian" = ", sigma)",
"bernoulli" = ")",
"neg_binomial" = ", kappa)",
"poisson" = ")"
), ";
"gaussian" = "normal_lpdf(Y[i, j] | linpred[i, j], sigma);",
"bernoulli" = "bernoulli_logit_lpmf(Y[i, j] | linpred[i, j]);",
"neg_binomial" = "neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa);",
"poisson" = "poisson_log_lpmf(Y[i, j] | linpred[i, j]);",
"binomial" = "binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);"
),"
}
}
}
Expand Down
46 changes: 34 additions & 12 deletions R/posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE,
"gaussian" = x,
"bernoulli" = inv_logit(x),
"poisson" = exp(x),
"neg_binomial" = exp(x)
"neg_binomial" = exp(x),
"binomial" = inv_logit(x)
)
})
}
Expand All @@ -184,6 +185,10 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE,
#'
#' @inheritParams posterior_linpred.jsdmStanFit
#'
#' @param Ntrials For the binomial distribution the number of trials, given as
#' either a single integer which is assumed to be constant across sites or as
#' a site-length vector of integers.
#'
#' @return A list of linear predictors. If list_index is \code{"draws"} (the default)
#' the list will have length equal to the number of draws with each element of
#' the list being a site x species matrix. If the list_index is \code{"species"} the
Expand All @@ -200,7 +205,8 @@ posterior_linpred.jsdmStanFit <- function(object, transform = FALSE,
posterior_predict.jsdmStanFit <- function(object, newdata = NULL,
newdata_type = "X", ndraws = NULL,
draw_ids = NULL,
list_index = "draws", ...) {
list_index = "draws",
Ntrials = NULL, ...) {
transform <- ifelse(object$family == "gaussian", FALSE, TRUE)
post_linpred <- posterior_linpred(object,
newdata = newdata, ndraws = ndraws,
Expand All @@ -214,20 +220,36 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL,
if (object$family == "neg_binomial") {
mod_kappa <- rstan::extract(object$fit, pars = "kappa", permuted = FALSE)
}
if(object$family == "binomial"){
if(is.null(newdata)) {
Ntrials <- object$data_list$Ntrials
} else {
Ntrials <- ntrials_check(Ntrials, nrow(newdata))
}
}

n_sites <- length(object$sites)
n_species <- length(object$species)

post_pred <- lapply(post_linpred, function(x, family = object$family) {
x2 <- x
x2 <- apply(x2, 1:2, function(x) {
switch(object$family,
"gaussian" = stats::rnorm(1, x, mod_sigma),
"bernoulli" = stats::rbinom(1, 1, x),
"poisson" = stats::rpois(1, x),
"neg_binomial" = rgampois(1, x, mod_kappa)
)
})
post_pred <- lapply(seq_along(post_linpred),
function(x, family = object$family) {
x2 <- post_linpred[[x]]
if(family == "binomial"){
for(i in 1:nrow(x2)){
for(j in 1:ncol(x2)){
x2[i,j] <- stats::rbinom(1, Ntrials[i], x2[i,j])
}
}
} else {
x2 <- apply(x2, 1:2, function(x) {
switch(object$family,
"gaussian" = stats::rnorm(1, x, mod_sigma),
"bernoulli" = stats::rbinom(1, 1, x),
"poisson" = stats::rpois(1, x),
"neg_binomial" = rgampois(1, x, mod_kappa)
)
})
}
x2
})

Expand Down
91 changes: 62 additions & 29 deletions R/sim_data_funs.R
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
#' Generate simulated data within a variety of jSDM methodologies
#'
#' The \code{jsdm_sim_data} function can simulate data with either a multivariate
#' generalised mixed model (MGLMM) or a generalised linear latent variable model
#' (GLLVM). The \code{gllvm_sim_data} and \code{mglmm_sim_data} are aliases for
#' \code{jsdm_sim_data} that set \code{method} to \code{"gllvm"} and \code{"mglmm"}
#' respectively.
#' The \code{jsdm_sim_data} function can simulate data with either a
#' multivariate generalised mixed model (MGLMM) or a generalised linear latent
#' variable model (GLLVM). The \code{gllvm_sim_data} and \code{mglmm_sim_data}
#' are aliases for \code{jsdm_sim_data} that set \code{method} to \code{"gllvm"}
#' and \code{"mglmm"} respectively.
#'
#' @details This simulates data based on a joint species distribution model with
#' either a generalised linear latent variable model approach or a multivariate
#' generalised linear mixed model approach.
#' either a generalised linear latent variable model approach or a
#' multivariate generalised linear mixed model approach.
#'
#' Models can be fit with or without "measured predictors", and if measured
#' predictors are included then the species have species-specific parameter
#' estimates. These can either be simulated completely independently, or have
#' information pooled across species. If information is pooled this can be modelled
#' as either a random draw from some mean and standard deviation or species
#' covariance can be modelled together (this will be the covariance used in the
#' overall model if the method used has covariance).
#' information pooled across species. If information is pooled this can be
#' modelled as either a random draw from some mean and standard deviation or
#' species covariance can be modelled together (this will be the covariance
#' used in the overall model if the method used has covariance).
#'
#' Environmental covariate effects (\code{"betas"}) can be parameterised in two
#' ways. With the \code{"cor"} parameterisation all covariate effects are assumed
#' to be constrained by a correlation matrix between the covariates. With the
#' \code{"unstruct"} parameterisation all covariate effects are assumed to draw
#' from a simple distribution with no correlation structure. Both parameterisations
#' can be modified using the prior object.
#' Environmental covariate effects (\code{"betas"}) can be parameterised in
#' two ways. With the \code{"cor"} parameterisation all covariate effects are
#' assumed to be constrained by a correlation matrix between the covariates.
#' With the \code{"unstruct"} parameterisation all covariate effects are
#' assumed to draw from a simple distribution with no correlation structure.
#' Both parameterisations can be modified using the prior object.
#'
#' @export
#'
Expand All @@ -36,30 +36,36 @@
#' @param K is number of covariates, by default \code{0}
#'
#' @param family is the response family, must be one of \code{"gaussian"},
#' \code{"neg_binomial"}, \code{"poisson"} or \code{"bernoulli"}. Regular
#' expression matching is supported.
#' \code{"neg_binomial"}, \code{"poisson"}, \code{"binomial"},
#' or \code{"bernoulli"}. Regular expression matching is supported.
#'
#' @param method is the jSDM method to use, currently either \code{"gllvm"} or
#' \code{"mglmm"} - see details for more information.
#'
#' @param species_intercept Whether to include an intercept in the predictors, must
#' be \code{TRUE} if \code{K} is \code{0}. Defaults to \code{TRUE}.
#' @param species_intercept Whether to include an intercept in the predictors,
#' must be \code{TRUE} if \code{K} is \code{0}. Defaults to \code{TRUE}.
#'
#' @param Ntrials For the binomial distribution the number of trials, given as
#' either a single integer which is assumed to be constant across sites or as
#' a site-length vector of integers.
#'
#' @param site_intercept Whether a site intercept should be included, potential
#' values \code{"none"} (no site intercept) or \code{"ungrouped"} (site intercept
#' with no grouping). Defaults to no site intercept, grouped is not supported
#' currently.
#' values \code{"none"} (no site intercept) or \code{"ungrouped"} (site
#' intercept with no grouping). Defaults to no site intercept, grouped is not
#' supported currently.
#'
#' @param beta_param The parameterisation of the environmental covariate effects, by
#' default \code{"unstruct"}. See details for further information.
#' @param beta_param The parameterisation of the environmental covariate
#' effects, by default \code{"unstruct"}. See details for further information.
#'
#' @param prior Set of prior specifications from call to [jsdm_prior()]
jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "mglmm"),
species_intercept = TRUE,
Ntrials = NULL,
site_intercept = "none",
beta_param = "unstruct",
prior = jsdm_prior()) {
response <- match.arg(family, c("gaussian", "neg_binomial", "poisson", "bernoulli"))
response <- match.arg(family, c("gaussian", "neg_binomial", "poisson",
"bernoulli", "binomial"))
site_intercept <- match.arg(site_intercept, c("none","ungrouped","grouped"))
beta_param <- match.arg(beta_param, c("cor", "unstruct"))
if(site_intercept == "grouped"){
Expand Down Expand Up @@ -89,6 +95,10 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
stop("prior object must be of class jsdmprior, produced by jsdm_prior()")
}

if(response == "binomial"){
Ntrials <- ntrials_check(Ntrials = Ntrials, N = N)
}

# prior object breakdown
prior_split <- lapply(prior, strsplit, split = "\\(|\\)|,")
if (!all(sapply(prior_split, function(x) {
Expand Down Expand Up @@ -305,7 +315,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
),
"gaussian" = stats::rnorm(1, mu_ij, sigma),
"poisson" = stats::rpois(1, exp(mu_ij)),
"bernoulli" = stats::rbinom(1, 1, inv_logit(mu_ij))
"bernoulli" = stats::rbinom(1, 1, inv_logit(mu_ij)),
"binomial" = stats::rbinom(1, Ntrials[i], inv_logit(mu_ij))
)
}
}
Expand All @@ -319,6 +330,9 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
if(beta_param == "cor"){
pars$sigmas_preds <- sigmas_preds
pars$z_preds <- z_preds
if(K != 0){
pars$cor_preds <- cor_preds
}
}

if (site_intercept == "ungrouped") {
Expand Down Expand Up @@ -352,6 +366,9 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
output <- list(
Y = Y, pars = pars, N = N, S = S, D = D, K = J, X = x
)
if(response == "binomial"){
output$Ntrials <- Ntrials
}

return(output)
}
Expand Down Expand Up @@ -431,7 +448,7 @@ rgampois <- function(n, mu, scale) {
}

inv_logit <- function(x) {
1 / (1 + exp(x))
1 / (1 + exp(-x))
}


Expand Down Expand Up @@ -506,3 +523,19 @@ rinvgamma <- function(n, shape, scale) {
rstudentt <- function(n, df, mu, sigma) {
mu + sigma * stats::rt(n, df = df)
}

ntrials_check <- function(Ntrials, N){
if(is.null(Ntrials)){
stop("Number of trials must be specified for the binomial distribution")
}
if(!is.double(Ntrials) & !is.integer(Ntrials)){
stop("Ntrials must be a positive integer")
}
if(!(length(Ntrials) %in% c(1, N))){
stop("Ntrials must be of length 1 or N")
}
if(length(Ntrials) == 1L){
Ntrials <- rep(Ntrials, N)
}
Ntrials
}
Loading

0 comments on commit 645d05e

Please sign in to comment.