forked from mjhajharia/transforms
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make Stan implementations more modular (#1)
* Rename target_densities to targets
* Merge dirichlet and log-dirichlet definitions
* Remove function block for easy inclusion
* Move log-Dirichlet to its own stanfunction
* Add model blocks file for Dirichlet
* Add multi-logit-normal implementation
  Adapted from mjhajharia#74
* Make `log_` a prefix
* Fix density of multi-logit-normal
* Fix docstring
* Move transforms to blocks subdirectory
* Move ALR functions code to stanfunctions file
* Rename to stanfunctions extension
* Move ExpandedSoftmax functions to stanfunctions file
* Fix target densities
* Run stanc formatter on targets
* Run stanc formatter on transforms
* Make targets completely modular
* Rename transforms files
* Make transforms modular
* Remove outdated stan_models directory
* Strip trailing newlines
* Prefix targets/transforms with name
  This is necessary for using Stan includes
* Run formatter
* Remove leftover blocks
* Fix log-Dirichlet normalization factor
* Update test for transforms
* Fix function file name
* Test also multi-logit-normal
* Add data-only argument qualifiers for ILR's V matrix
* Update ALR constructor call
* Test StanStickbreaking
* Use append_row in ALR
* Declare real variables only where defined
* Set seed for sampling and increase sigfigs
* Add docstring to log_dirichlet_lpdf
* Make log_dirichlet_lpdf define the return while returning
  For consistency with the multi-logit-normal variants
* Implicitly increment target outside of loops
* Fix comment
* Retain all sig-figs, and choose new stan seed
* Concentrate draws towards middle of the simplex
  This is where transforms are most likely to be numerically stable, so the small numerical differences between jax and stan transforms are less likely to be relevant.
Showing 63 changed files with 601 additions and 683 deletions.
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
data { | ||
int<lower=1> N; | ||
vector<lower=0>[N] alpha; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
/** | ||
* Return the log density of the Dirichlet distribution for the specified log simplex. | ||
* | ||
* @param log_theta a vector on the log simplex (N rows) | ||
* @param alpha prior counts plus one (N rows) | ||
*/ | ||
real log_dirichlet_lpdf(vector log_theta, vector alpha) { | ||
int N = rows(log_theta); | ||
if (N != rows(alpha)) | ||
reject("Input must contain same number of elements as alpha"); | ||
return dot_product(alpha, log_theta) - log_theta[N] | ||
+ lgamma(sum(alpha)) - sum(lgamma(alpha)); | ||
} |
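
As a sanity check on the formula in `log_dirichlet_lpdf`, the sketch below ports it to plain Python (standard library only). The port and the helper `total_mass_n2` are illustrative assumptions and are not part of this commit. The `- log_theta[N]` term is what is left after combining the usual `alpha - 1` exponents with the Jacobian of mapping the first N-1 log coordinates to the simplex. For N = 2 the density has a single free coordinate, so it can be integrated numerically; the result should come out close to 1.

```python
import math


def log_dirichlet_lpdf(log_theta, alpha):
    """Python port of the Stan function above (illustrative, not from the commit)."""
    if len(log_theta) != len(alpha):
        raise ValueError("Input must contain same number of elements as alpha")
    # sum(alpha_i * log_theta_i) - log_theta_N, as in the Stan return statement
    kernel = sum(a * lt for a, lt in zip(alpha, log_theta)) - log_theta[-1]
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    return kernel + log_norm


def total_mass_n2(alpha, lower=-30.0, steps=30000):
    """Midpoint-rule integral of the N = 2 density over the free coordinate log_theta[0] < 0."""
    h = -lower / steps
    total = 0.0
    for k in range(steps):
        y1 = lower + (k + 0.5) * h
        y2 = math.log1p(-math.exp(y1))  # log(1 - theta_1), the dependent coordinate
        total += math.exp(log_dirichlet_lpdf([y1, y2], alpha)) * h
    return total


mass = total_mass_n2([2.0, 3.0])  # hypothetical alpha chosen only for this check
print(mass)
```

The lower limit of -30 and the step count are arbitrary choices; the integrand decays like `exp(alpha[0] * y1)`, so the truncated tail is negligible for this `alpha`.
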
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
model { | ||
target += log_dirichlet_lpdf(log_x | alpha); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
model { | ||
target += dirichlet_lpdf(x | alpha); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
data { | ||
int<lower=1> N; | ||
vector[N - 1] mu; | ||
matrix[N - 1, N - 1] L_Sigma; | ||
} |
31 changes: 31 additions & 0 deletions
targets/multi-logit-normal/multi-logit-normal_functions.stan
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
/** | ||
* Return the log of the multivariate logistic normal density for the specified simplex. | ||
* | ||
* See: https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization | ||
* | ||
* @param theta a vector on the simplex (N rows) | ||
* @param mu location of normal (N-1 rows) | ||
* @param L_Sigma Cholesky factor of covariance (N-1 rows, N-1 cols) | ||
*/ | ||
real multi_logit_normal_cholesky_lpdf(vector theta, vector mu, matrix L_Sigma) { | ||
int N = rows(theta); | ||
vector[N] log_theta = log(theta); | ||
return multi_normal_cholesky_lpdf(log_theta[1 : N - 1] - log_theta[N] | mu, L_Sigma) | ||
- sum(log_theta); | ||
} | ||
 | ||
/** | ||
* Return the log of the multivariate logistic normal density for the specified log simplex. | ||
* | ||
* See: https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization | ||
* | ||
* @param log_theta a vector on the log simplex (N rows) | ||
* @param mu location of normal (N-1 rows) | ||
* @param L_Sigma Cholesky factor of covariance (N-1 rows, N-1 cols) | ||
*/ | ||
real log_multi_logit_normal_cholesky_lpdf(vector log_theta, vector mu, | ||
matrix L_Sigma) { | ||
int N = rows(log_theta); | ||
return multi_normal_cholesky_lpdf(log_theta[1 : N - 1] - log_theta[N] | mu, L_Sigma) | ||
- log_theta[N]; | ||
} |
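
The two variants above differ only in their Jacobian correction: `- sum(log_theta)` on the simplex and `- log_theta[N]` on the log simplex. The gap between them should therefore equal the sum of the first N-1 log coordinates, which is the log Jacobian of taking logs of the free coordinates. The plain-Python sketch below checks this and also integrates the N = 2 simplex density. It is an illustrative port: the forward-substitution implementation of `multi_normal_cholesky_lpdf`, the helper `total_mass_n2`, and all numeric values are assumptions made for the check.

```python
import math


def multi_normal_cholesky_lpdf(v, mu, L):
    """Log density of N(mu, L L^T) at v, with L a lower-triangular list of rows."""
    n = len(v)
    z = []
    for i in range(n):
        # Forward substitution: solve L z = v - mu one row at a time.
        acc = v[i] - mu[i] - sum(L[i][j] * z[j] for j in range(i))
        z.append(acc / L[i][i])
    log_det = sum(math.log(L[i][i]) for i in range(n))
    return -0.5 * sum(zi * zi for zi in z) - log_det - 0.5 * n * math.log(2.0 * math.pi)


def multi_logit_normal_cholesky_lpdf(theta, mu, L):
    log_theta = [math.log(t) for t in theta]
    alr = [lt - log_theta[-1] for lt in log_theta[:-1]]
    return multi_normal_cholesky_lpdf(alr, mu, L) - sum(log_theta)


def log_multi_logit_normal_cholesky_lpdf(log_theta, mu, L):
    alr = [lt - log_theta[-1] for lt in log_theta[:-1]]
    return multi_normal_cholesky_lpdf(alr, mu, L) - log_theta[-1]


# Hypothetical inputs for the consistency check (N = 3).
theta = [0.2, 0.5, 0.3]
mu = [0.1, -0.4]
L = [[1.0, 0.0], [0.3, 0.7]]
log_theta = [math.log(t) for t in theta]
gap = (log_multi_logit_normal_cholesky_lpdf(log_theta, mu, L)
       - multi_logit_normal_cholesky_lpdf(theta, mu, L))
jacobian = sum(log_theta[:-1])  # log Jacobian of theta[:-1] -> log(theta[:-1])


def total_mass_n2(mu1, sigma, steps=20000):
    """Midpoint-rule integral of the N = 2 simplex density over theta_1 in (0, 1)."""
    h = 1.0 / steps
    total = 0.0
    for k in range(steps):
        x1 = (k + 0.5) * h
        total += math.exp(multi_logit_normal_cholesky_lpdf([x1, 1.0 - x1], [mu1], [[sigma]])) * h
    return total


mass = total_mass_n2(0.3, 0.8)
print(gap - jacobian, mass)
```
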
3 changes: 3 additions & 0 deletions
targets/multi-logit-normal/multi-logit-normal_model_log_simplex.stan
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
model { | ||
target += log_multi_logit_normal_cholesky_lpdf(log_x | mu, L_Sigma); | ||
} |
3 changes: 3 additions & 0 deletions
targets/multi-logit-normal/multi-logit-normal_model_simplex.stan
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
model { | ||
target += multi_logit_normal_cholesky_lpdf(x | mu, L_Sigma); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
vector inv_alr_simplex_constrain_lp(vector y) { | ||
int N = rows(y) + 1; | ||
real r = log1p_exp(log_sum_exp(y)); | ||
vector[N] x = append_row(exp(y - r), exp(-r)); | ||
target += y; | ||
target += -N * r; | ||
return x; | ||
} | ||
 | ||
vector inv_alr_log_simplex_constrain_lp(vector y) { | ||
int N = rows(y) + 1; | ||
real r = log1p_exp(log_sum_exp(y)); | ||
vector[N] log_x = append_row(y - r, -r); | ||
target += -r; | ||
return log_x; | ||
} |
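
The `target +=` statements in the two inverse-ALR functions add the log absolute Jacobian determinant of the map from `y` to the first N-1 output coordinates: `sum(y) - N * r` for the simplex and `-r` for the log simplex. The sketch below compares both against a finite-difference Jacobian in plain Python. It is only an illustration; the helpers `inv_alr`, `det`, and `numeric_log_abs_det` and the test point `y` are assumptions and do not come from the commit.

```python
import math


def inv_alr(y):
    """Return (x, log_x, r) for the inverse ALR map of y (length N-1)."""
    r = math.log1p(sum(math.exp(v) for v in y))  # log(1 + sum(exp(y)))
    log_x = [v - r for v in y] + [-r]
    x = [math.exp(v) for v in log_x]
    return x, log_x, r


def det(m):
    """Determinant by Laplace expansion along the first row (fine for tiny matrices)."""
    if len(m) == 1:
        return m[0][0]
    total = 0.0
    for j, a in enumerate(m[0]):
        minor = [row[:j] + row[j + 1:] for row in m[1:]]
        total += (-1) ** j * a * det(minor)
    return total


def numeric_log_abs_det(f, y, h=1e-6):
    """log|det J| of f at y via central differences; f maps len(y) values to len(y) values."""
    n = len(y)
    cols = []
    for j in range(n):
        up = list(y)
        dn = list(y)
        up[j] += h
        dn[j] -= h
        cols.append([(a - b) / (2.0 * h) for a, b in zip(f(up), f(dn))])
    jac = [[cols[j][i] for j in range(n)] for i in range(n)]
    return math.log(abs(det(jac)))


y = [0.4, -1.1, 0.7]  # hypothetical unconstrained point, so N = 4
N = len(y) + 1
_, _, r = inv_alr(y)

simplex_term = sum(y) - N * r   # what inv_alr_simplex_constrain_lp adds to target
log_simplex_term = -r           # what inv_alr_log_simplex_constrain_lp adds to target

numeric_simplex = numeric_log_abs_det(lambda v: inv_alr(v)[0][:-1], y)
numeric_log_simplex = numeric_log_abs_det(lambda v: inv_alr(v)[1][:-1], y)
print(simplex_term - numeric_simplex, log_simplex_term - numeric_log_simplex)
```

Both printed differences should be near zero, limited only by finite-difference error.
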
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
parameters { | ||
vector[N - 1] y; | ||
} | ||
transformed parameters { | ||
vector<upper=0>[N] log_x = inv_alr_log_simplex_constrain_lp(y); | ||
simplex[N] x = exp(log_x); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
parameters { | ||
vector[N - 1] y; | ||
} | ||
transformed parameters { | ||
simplex[N] x = inv_alr_simplex_constrain_lp(y); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
vector expanded_softmax_simplex_constrain_lp(vector y) { | ||
int N = rows(y); | ||
real r = log_sum_exp(y); | ||
vector[N] x = exp(y - r); | ||
target += y; | ||
target += -N * r; | ||
target += std_normal_lpdf(r - log(N)); | ||
return x; | ||
} | ||
 | ||
vector expanded_softmax_log_simplex_constrain_lp(vector y) { | ||
int N = rows(y); | ||
real r = log_sum_exp(y); | ||
vector[N] log_x = y - r; | ||
target += log_x[N]; | ||
target += std_normal_lpdf(r - log(N)); | ||
return log_x; | ||
} |
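
The expanded softmax takes N unconstrained values, so the map only becomes a bijection once the extra coordinate `r = log_sum_exp(y)` is kept alongside the first N-1 simplex coordinates. Under that reading, the Jacobian terms are `sum(y) - N * r` for the simplex variant and `log_x[N]` for the log-simplex variant. The `std_normal_lpdf(r - log(N))` term is a separate density on the extra coordinate and not part of the Jacobian. The plain-Python sketch below checks the two Jacobian terms against finite differences. The helper names and the test point are assumptions made for illustration only.

```python
import math


def expanded_softmax(y):
    """Return (x, log_x, r) with r = log_sum_exp(y) and log_x = y - r."""
    m = max(y)
    r = m + math.log(sum(math.exp(v - m) for v in y))
    log_x = [v - r for v in y]
    return [math.exp(v) for v in log_x], log_x, r


def det(m):
    """Determinant by Laplace expansion along the first row (fine for tiny matrices)."""
    if len(m) == 1:
        return m[0][0]
    return sum((-1) ** j * a * det([row[:j] + row[j + 1:] for row in m[1:]])
               for j, a in enumerate(m[0]))


def numeric_log_abs_det(f, y, h=1e-6):
    """log|det J| of f at y via central differences; f maps len(y) values to len(y) values."""
    n = len(y)
    cols = []
    for j in range(n):
        up = list(y)
        dn = list(y)
        up[j] += h
        dn[j] -= h
        cols.append([(a - b) / (2.0 * h) for a, b in zip(f(up), f(dn))])
    return math.log(abs(det([[cols[j][i] for j in range(n)] for i in range(n)])))


def to_simplex_and_r(v):
    x, _, r = expanded_softmax(v)
    return x[:-1] + [r]  # first N-1 simplex coordinates plus the extra coordinate


def to_log_simplex_and_r(v):
    _, log_x, r = expanded_softmax(v)
    return log_x[:-1] + [r]


y = [0.3, -0.8, 1.2]  # hypothetical unconstrained point, N = 3
N = len(y)
_, log_x, r = expanded_softmax(y)

simplex_term = sum(y) - N * r   # Jacobian part of expanded_softmax_simplex_constrain_lp
log_simplex_term = log_x[-1]    # Jacobian part of expanded_softmax_log_simplex_constrain_lp

numeric_simplex = numeric_log_abs_det(to_simplex_and_r, y)
numeric_log_simplex = numeric_log_abs_det(to_log_simplex_and_r, y)
print(simplex_term - numeric_simplex, log_simplex_term - numeric_log_simplex)
```
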
7 changes: 7 additions & 0 deletions
transforms/ExpandedSoftmax/ExpandedSoftmax_parameters_log_simplex.stan
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
parameters { | ||
vector[N] y; | ||
} | ||
transformed parameters { | ||
vector<upper=0>[N] log_x = expanded_softmax_log_simplex_constrain_lp(y); | ||
simplex[N] x = exp(log_x); | ||
} |