multi logit normal density #74

Open · wants to merge 2 commits into main
Conversation

bob-carpenter
Collaborator

Added the multi logit normal density example as a function.

I wrote some code to test it, but didn't know where to put it, if anywhere:

functions {
  /**
   * Return the multivariate logit-normal log density for the specified simplex.
   *
   * See: https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization
   * 
   * @param theta a simplex (N rows)
   * @param mu location of normal (N-1 rows)
   * @param L_Sigma Cholesky factor of covariance (N-1 rows, N-1 cols)
   */
  real multi_logit_normal_cholesky_lpdf(vector theta, vector mu, matrix L_Sigma) {
    int N = rows(theta);
    return sum(-log(theta))
      + multi_normal_cholesky_lpdf(log(theta[1:N - 1] / theta[N]) | mu, L_Sigma);
  }
}
data {
  int<lower=1> N;
  vector[N - 1] mu;
  matrix[N - 1, N - 1] L_Sigma;
}
parameters {
  simplex[N] theta;
}
model {
  theta ~ multi_logit_normal_cholesky(mu, L_Sigma);
}
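For reference, here is a NumPy sketch of the same density (my own check, not part of the PR; the function name and the spot-check values are mine). It whitens the additive log-ratio with the Cholesky factor and adds the Jacobian term, mirroring the Stan function above:

```python
import math
import numpy as np

def multi_logit_normal_cholesky_lpdf(theta, mu, L):
    """Log density of the multivariate logistic normal at simplex point theta.

    theta: length-N simplex; mu: length N-1; L: (N-1, N-1) Cholesky factor.
    """
    theta = np.asarray(theta, dtype=float)
    mu = np.asarray(mu, dtype=float)
    L = np.atleast_2d(np.asarray(L, dtype=float))
    y = np.log(theta[:-1] / theta[-1])      # additive log-ratio transform
    z = np.linalg.solve(L, y - mu)          # whiten with the Cholesky factor
    k = len(mu)
    normal_lpdf = (-0.5 * z @ z
                   - np.sum(np.log(np.diag(L)))
                   - 0.5 * k * math.log(2.0 * math.pi))
    # Jacobian adjustment sum(-log(theta)) plus the normal log density
    return -np.sum(np.log(theta)) + normal_lpdf

# Spot check at the center of the 1-simplex with mu = 0, L = [[1]]:
# alr([0.5, 0.5]) = 0, so the value is 2*log(2) plus the standard
# normal log density at 0.
val = multi_logit_normal_cholesky_lpdf([0.5, 0.5], [0.0], [[1.0]])
```

This is only a reference implementation for checking values against the Stan function, not something optimized for repeated evaluation.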

When I fit that for N = 11 (matching the 11 simplex components below), I get

>>> fit.summary()
               Mean      MCSE    StdDev         5%       50%       95%    N_Eff   N_Eff/s     R_hat
lp__      -9.707250  0.118998  2.287330 -13.981800 -9.352490 -6.594010  369.468   9722.84  1.002890
theta[1]   0.101758  0.003245  0.080631   0.017135  0.078687  0.279115  617.543  16251.10  0.999185
theta[2]   0.093105  0.003114  0.065519   0.016626  0.078280  0.222328  442.704  11650.10  0.999218
theta[3]   0.091287  0.002735  0.061166   0.019894  0.078057  0.207519  500.311  13166.10  0.999022
theta[4]   0.086258  0.002170  0.053007   0.023276  0.074910  0.187406  596.702  15702.70  0.999006
theta[5]   0.086327  0.001873  0.052215   0.023367  0.073449  0.183068  777.025  20448.00  1.000590
theta[6]   0.086985  0.001846  0.052987   0.022538  0.075478  0.190856  823.772  21678.20  1.001630
theta[7]   0.087918  0.002201  0.057098   0.021416  0.076411  0.203125  673.113  17713.50  0.999299
theta[8]   0.089039  0.002516  0.059121   0.023527  0.075369  0.205456  552.153  14530.30  0.999014
theta[9]   0.092057  0.003195  0.065958   0.020491  0.075871  0.221874  426.106  11213.30  0.999156
theta[10]  0.095794  0.003665  0.078532   0.017671  0.073461  0.261481  459.097  12081.50  0.999713
theta[11]  0.089472  0.002335  0.061425   0.021900  0.074570  0.211407  691.799  18205.20  1.000490

I'm not 100% sure how to test here without doing simulation-based calibration, but the correlations are along the lines I'd expect:

>>> theta_draws = fit.stan_variable('theta')

>>> np.shape(theta_draws)
(1000, 11)

>>> np.corrcoef(theta_draws[:, 1], theta_draws[:, 2])
array([[1.      , 0.451464],
       [0.451464, 1.      ]])

>>> np.corrcoef(theta_draws[:, 1], theta_draws[:, 3])
array([[1.        , 0.11187134],
       [0.11187134, 1.        ]])

>>> np.corrcoef(theta_draws[:, 1], theta_draws[:, 4])
array([[ 1.       , -0.1824799],
       [-0.1824799,  1.       ]])

>>> np.corrcoef(theta_draws[:, 1], theta_draws[:, 10])
array([[ 1.        , -0.05364937],
       [-0.05364937,  1.        ]])

>>> np.corrcoef(theta_draws[:, 2], theta_draws[:, 3])
array([[1.        , 0.45617112],
       [0.45617112, 1.        ]])

The entry at [0, 1] is the correlation of interest. Sorry for not making this neater, but I closed the session.
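One cheap forward check (my own sketch, not from the thread): draw from the logistic normal directly by sampling z ~ N(mu, L L'), appending a zero logit, and applying softmax; the resulting draws should have marginal correlations comparable to the posterior draws above. The mu and L below are placeholders, since the actual data values aren't shown in the thread:

```python
import numpy as np

rng = np.random.default_rng(0)
N = 11                      # matches the 11 simplex components above
mu = np.zeros(N - 1)        # placeholder: real values not shown in the thread
L = np.eye(N - 1)           # placeholder Cholesky factor

# z ~ N(mu, L L') via an affine transform of standard normals
z = rng.standard_normal((1000, N - 1)) @ L.T + mu

# Append the reference zero logit and softmax onto the simplex
logits = np.concatenate([z, np.zeros((1000, 1))], axis=1)
theta_draws = np.exp(logits - logits.max(axis=1, keepdims=True))
theta_draws /= theta_draws.sum(axis=1, keepdims=True)

# np.corrcoef(theta_draws[:, i], theta_draws[:, j]) can then be compared
# with the posterior correlations reported above.
```

Comparing these forward-simulated correlations against the posterior correlations is much weaker than simulation-based calibration, but it catches sign and magnitude errors in the density cheaply.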

@sethaxen (Collaborator) left a comment:
A few minor suggestions. Could you also add the log-simplex version (taking log_theta as input, with only lp += -log_theta[N] as the adjustment)?
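A sketch of that log-simplex variant, as I read the suggestion (the function name is hypothetical, not code from the PR); the additive log-ratio becomes a plain difference of logs, and only the -log_theta[N] adjustment remains:

```stan
/**
 * Return the multivariate logit-normal log density for the specified
 * log simplex (log_theta = log(theta)).
 *
 * @param log_theta log of a simplex (N rows)
 * @param mu location of normal (N-1 rows)
 * @param L_Sigma Cholesky factor of covariance (N-1 rows, N-1 cols)
 */
real multi_logit_normal_log_cholesky_lpdf(vector log_theta, vector mu,
                                          matrix L_Sigma) {
  int N = rows(log_theta);
  return -log_theta[N]
    + multi_normal_cholesky_lpdf(log_theta[1:N - 1] - log_theta[N] | mu, L_Sigma);
}
```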

Review threads on target_densities/multi-logit-normal.stan (outdated, resolved).
sethaxen added a commit to bob-carpenter/transforms that referenced this pull request Jun 14, 2024
sethaxen added a commit to bob-carpenter/transforms that referenced this pull request Jun 19, 2024
* Rename target_densities to targets

* Merge dirichlet and log-dirichlet definitions

* Remove function block for easy inclusion

* Move log-Dirichlet to its own stanfunction

* Add model blocks file for Dirichlet

* Add multi-logit-normal implementation

Adapted from mjhajharia#74

* Make `log_` a prefix

* Fix density of multi-logit-normal

* Fix docstring

* Move transforms to blocks subdirectory

* Move ALR functions code to stanfunctions file

* Rename to stanfunctions extension

* Move ExpandedSoftmax functions to stanfunctions file

* Fix target densities

* Run stanc formatter on targets

* Run stanc formatter on transforms

* Make targets completely modular

* Rename transforms files

* Make transforms modular

* Remove outdated stan_models directory

* Strip trailing newlines

* Prefix targets/transforms with name

This is necessary for using Stan includes

* Run formatter

* Remove leftover blocks

* Fix log-Dirichlet normalization factor

* Update test for transforms

* Fix function file name

* Test also multi-logit-normal

* Add data-only argument qualifiers for ILR's V matrix

* Update ALR constructor call

* Test StanStickbreaking

* Use append_row in ALR

* Declare real variables only where defined

* Set seed for sampling and increase sigfigs

* Add docstring to log_dirichlet_lpdf

* Make log_dirichlet_lpdf define the return while returning

For consistency with the multi-logit-normal variants

* Implicitly increment target outside of loops

* Fix comment

* Retain all sig-figs, and choose new stan seed

* Concentrate draws towards middle of the simplex

This is where transforms are most likely to be numerically stable, so the small numerical differences between the JAX and Stan transforms are less likely to be relevant.
2 participants