Skip to content

Commit

Permalink
Updating problems with binomial
Browse files Browse the repository at this point in the history
Also removing all remnants of phylo model
  • Loading branch information
fseaton committed Dec 20, 2023
1 parent 4b197a1 commit 9b7f25f
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 60 deletions.
15 changes: 9 additions & 6 deletions R/jsdm_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,29 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(),
data <- paste(
" int<lower=1> N; // Number of sites
int<lower=1> S; // Number of species
", ifelse(method == "gllvm",
",
ifelse(method == "gllvm",
" int<lower=1> D; // Number of latent dimensions", ""
),
ifelse(family == "binomial",
" int<lower=0> Ntrials[N]; // Number of trials",""),
"
int<lower=0> K; // Number of predictor variables
matrix[N, K] X; // Predictor matrix
", ifelse(site_intercept == "grouped",
",
ifelse(site_intercept == "grouped",
"
int<lower=1> ngrp; // Number of groups in site intercept
int<lower=0, upper = ngrp> grps[N]; // Vector matching sites to groups
",""),
",""),
switch(family,
"gaussian" = "real",
"bernoulli" = "int<lower=0,upper=1>",
"neg_binomial" = "int<lower=0>",
"poisson" = "int<lower=0>",
"binomial" = "int<lower=0>"
), "Y[N,S]; //Species matrix"
), "Y[N,S]; //Species matrix",
ifelse(family == "binomial",
"
int<lower=0> Ntrials[N]; // Number of trials","")
)
transformed_data <- ifelse(method == "gllvm", "
// Ensures identifiability of the model - no rotation of factors
Expand Down
8 changes: 6 additions & 2 deletions R/posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,12 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL,
if (object$family == "neg_binomial") {
mod_kappa <- rstan::extract(object$fit, pars = "kappa", permuted = FALSE)
}
if(object$family == "binomial" & is.null(newdata)) {
Ntrials <- object$data_list$Ntrials
if(object$family == "binomial"){
if(is.null(newdata)) {
Ntrials <- object$data_list$Ntrials
} else {
Ntrials <- ntrials_check(Ntrials, nrow(newdata))
}
}

n_sites <- length(object$sites)
Expand Down
29 changes: 17 additions & 12 deletions R/sim_data_funs.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,7 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
}

if(response == "binomial"){
if(is.null(Ntrials)){
stop("Number of trials must be specified for the binomial distribution")
}
if(!is.double(Ntrials) & !is.integer(Ntrials)){
stop("Ntrials must be a positive integer")
}
if(!(length(Ntrials) %in% c(1, N))){
stop("Ntrials must be of length 1 or N")
}
if(length(Ntrials) == 1L){
Ntrials <- rep(Ntrials, N)
}
Ntrials <- ntrials_check(Ntrials = Ntrials, N = N)
}

# prior object breakdown
Expand Down Expand Up @@ -534,3 +523,19 @@ rinvgamma <- function(n, shape, scale) {
rstudentt <- function(n, df, mu, sigma) {
mu + sigma * stats::rt(n, df = df)
}

ntrials_check <- function(Ntrials, N){
if(is.null(Ntrials)){
stop("Number of trials must be specified for the binomial distribution")
}
if(!is.double(Ntrials) & !is.integer(Ntrials)){
stop("Ntrials must be a positive integer")
}
if(!(length(Ntrials) %in% c(1, N))){
stop("Ntrials must be of length 1 or N")
}
if(length(Ntrials) == 1L){
Ntrials <- rep(Ntrials, N)
}
Ntrials
}
37 changes: 9 additions & 28 deletions R/stan_jsdm.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ stan_jsdm.default <- function(X = NULL, Y = NULL, species_intercept = TRUE, meth
data_list <- validate_data(
Y = Y, X = X, species_intercept = species_intercept,
D = D, site_intercept = site_intercept, site_groups = site_groups,
dat_list = dat_list, phylo = FALSE,
family = family, method = method, nu05 = "1",
delta = 1e-5, Ntrials = Ntrials
dat_list = dat_list,
family = family, method = method, Ntrials = Ntrials
)

# Create stancode
Expand Down Expand Up @@ -226,8 +225,8 @@ stan_gllvm.formula <- function(formula, data = list(), ...) {
# Internal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

validate_data <- function(Y, D, X, species_intercept,
dat_list, family, site_intercept, phylo,
method, nu05, delta, site_groups, Ntrials) {
dat_list, family, site_intercept,
method, site_groups, Ntrials) {
method <- match.arg(method, c("gllvm", "mglmm"))

# do things if data not given as list:
Expand All @@ -252,6 +251,10 @@ validate_data <- function(Y, D, X, species_intercept,
X <- matrix(1, nrow = N, ncol = 1)
colnames(X) <- "(Intercept)"
} else {
if(is.null(colnames(X))){
message("No column names specified for X, assigning names")
colnames(X) <- paste0("V",seq_len(ncol(X)))
}
K <- ncol(X) + 1 * species_intercept
if(is.data.frame(X)){
X <- as.matrix(X)
Expand Down Expand Up @@ -284,11 +287,6 @@ validate_data <- function(Y, D, X, species_intercept,
data_list$ngrp <- ngrp
data_list$grps <- grps
}
if (!isFALSE(phylo)) {
data_list$Dmat <- phylo
data_list$nu05 <- nu05
data_list$delta <- delta
}
if(family == "binomial"){
data_list$Ntrials <- Ntrials
}
Expand All @@ -301,12 +299,6 @@ validate_data <- function(Y, D, X, species_intercept,
stop("If supplying data as a list must have a D entry")
}

if (!isFALSE(phylo)) {
if (!all(c("Dmat", "nu05", "delta") %in% names(dat_list))) {
stop("Phylo models require Dmat, nu05 and delta in dat_list")
}
}

if (identical(family, "binomial")) {
if (!all(c("Ntrials") %in% names(dat_list))) {
stop("Binomial models require Ntrials in dat_list")
Expand Down Expand Up @@ -357,18 +349,7 @@ validate_data <- function(Y, D, X, species_intercept,

# Check if Ntrials is appropriate given
if(identical(family, "binomial")) {
if(is.null(data_list$Ntrials)){
stop("Number of trials must be specified for the binomial distribution")
}
if(!is.double(data_list$Ntrials) & !is.integer(data_list$Ntrials)){
stop("Ntrials must be a positive integer")
}
if(!(length(data_list$Ntrials) %in% c(1, data_list$N))){
stop("Ntrials must be of length 1 or N")
}
if(length(data_list$Ntrials) == 1L){
data_list$Ntrials <- rep(data_list$Ntrials, data_list$N)
}
data_list$Ntrials <- ntrials_check(data_list$Ntrials, data_list$N)
}

return(data_list)
Expand Down
24 changes: 13 additions & 11 deletions R/update.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#' @param newY New Y data, by default \code{NULL}
#' @param newX New X data, by default \code{NULL}
#' @param newD New number of latent variables, by default \code{NULL}
#' @param newNtrials New number of trials (binomial model only), by default
#' \code{NULL}
#' @param save_data Whether to save the data in the jsdmStanFit object, by default
#' \code{TRUE}
#' @param ... Arguments passed to [rstan::sampling()]
Expand Down Expand Up @@ -49,6 +51,7 @@
#' gllvm_fit2
#' }
update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL,
newNtrials = NULL,
save_data = TRUE, ...) {
if (length(object$data_list) == 0) {
stop("Update requires the original data to be saved in the model object")
Expand Down Expand Up @@ -76,29 +79,28 @@ update.jsdmStanFit <- function(object, newY = NULL, newX = NULL, newD = NULL,
} else{
D <- object$data_list$D
}
if(family == "binomial") {
if(!is.null(newNtrials)){
Ntrials <- ntrials_check(newNtrials, nrow(Y))
} else{
Ntrials <- object$data_list$Ntrials
}
}

species_intercept <- "(Intercept)" %in% colnames(object$data_list$X)

site_intercept <- ifelse("ngrp" %in% names(object$data_list), "grouped",
ifelse("a" %in% get_parnames(object), "ungrouped",
"none"))
site_groups <- if(site_intercept == "grouped"){
object$data_list$grps} else{NULL}
phylo <- object$data_list$phylo
if (!isFALSE(phylo)) {
nu05 <- object$data_list$nu05
delta <- object$data_list$delta
} else {
nu05 <- 0L
delta <- 1e-5
}

# validate data
data_list <- validate_data(
Y = Y, X = X, species_intercept = species_intercept,
D = D, site_intercept = site_intercept, site_groups = site_groups,
dat_list = NULL, phylo = phylo,
family = family, method = method, nu05 = nu05,
delta = delta
dat_list = NULL,
family = family, method = method, Ntrials = Ntrials
)

# get original stan model
Expand Down
13 changes: 12 additions & 1 deletion man/update.jsdmStanFit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9b7f25f

Please sign in to comment.