Skip to content

Commit

Permalink
Make Stan implementations more modular (#1)
Browse files Browse the repository at this point in the history
* Rename target_densities to targets

* Merge dirichlet and log-dirichlet definitions

* Remove function block for easy inclusion

* Move log-Dirichlet to its own stanfunction

* Add model blocks file for Dirichlet

* Add multi-logit-normal implementation

Adapted from mjhajharia#74

* Make `log_` a prefix

* Fix density of multi-logit-normal

* Fix docstring

* Move transforms to blocks subdirectory

* Move ALR functions code to stanfunctions file

* Rename to stanfunctions extension

* Move ExpandedSoftmax functions to stanfunctions file

* Fix target densities

* Run stanc formatter on targets

* Run stanc formatter on transforms

* Make targets completely modular

* Rename transforms files

* Make transforms modular

* Remove outdated stan_models directory

* Strip trailing newlines

* Prefix targets/transforms with name

This is necessary for using Stan includes

* Run formatter

* Remove leftover blocks

* Fix log-Dirichlet normalization factor

* Update test for transforms

* Fix function file name

* Test also multi-logit-normal

* Add data-only argument qualifiers for ILR's V matrix

* Update ALR constructor call

* Test StanStickbreaking

* Use append_row in ALR

* Declare real variables only where defined

* Set seed for sampling and increase sigfigs

* Add docstring to log_dirichlet_lpdf

* Make log_dirichlet_lpdf define the return while returning

For consistency with the multi-logit-normal variants

* Implicitly increment target outside of loops

* Fix comment

* Retain all sig-figs, and choose new stan seed

* Concentrate draws towards middle of the simplex

This is where transforms are most likely to be numerically stable, so the small numerical differences between the JAX and Stan transforms are less likely to be relevant.
  • Loading branch information
sethaxen authored Jun 19, 2024
1 parent 079a686 commit 36a599d
Show file tree
Hide file tree
Showing 63 changed files with 601 additions and 683 deletions.
202 changes: 133 additions & 69 deletions jax_transforms/tests/test_stan_transforms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
import tempfile
from typing import NamedTuple

import arviz as az
import cmdstanpy
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from tensorflow_probability.substrates.jax import distributions as tfd

import jax_transforms

Expand All @@ -14,6 +16,7 @@
basic_transforms = [
"ALR",
"ILR",
"StanStickbreaking",
"StickbreakingAngular",
"StickbreakingLogistic",
"StickbreakingNormal",
Expand All @@ -27,84 +30,145 @@
]

# Absolute path obtained by applying three ".." components to this file's path.
project_dir = os.path.abspath(os.path.join(__file__, "..", "..", ".."))
# Directories searched for "<transform_name>.stan" model files: one for the
# simplex parameterization and one for the log-simplex parameterization.
stan_models_dir = os.path.join(project_dir, "transforms/simplex")
stan_models_logscale_dir = os.path.join(project_dir, "transforms/log_simplex")


@pytest.fixture(scope="module", params=basic_transforms + expanded_transforms)
def transform_and_model(request):
    """Compile the simplex Stan model of one transform with its target density zeroed.

    The model source is read from ``<stan_models_dir>/<transform_name>.stan``,
    the literal text ``target_density_lp(x, alpha)`` is replaced by ``0``, and
    the result is written to a temporary file and compiled with CmdStan.

    Yields:
        A ``(transform_name, model)`` tuple, where ``model`` is the compiled
        ``cmdstanpy.CmdStanModel``.
    """
    transform_name = request.param
    model_file = os.path.join(stan_models_dir, f"{transform_name}.stan")
    # Read through a context manager so the file handle is closed promptly
    # instead of being left to the garbage collector.
    with open(model_file, "r") as f:
        model_code = f.read()
    model_code = model_code.replace("target_density_lp(x, alpha)", "0")
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_model_fn = os.path.join(tmpdir, "model.stan")
        with open(tmp_model_fn, "w") as f:
            f.write(model_code)
        model = cmdstanpy.CmdStanModel(stan_file=tmp_model_fn)
        # NOTE(review): the yield is kept inside the TemporaryDirectory block so
        # the compiled files outlive the fixture's users — confirm this nesting
        # against the original indentation.
        yield transform_name, model


@pytest.fixture(scope="module", params=basic_transforms + expanded_transforms)
def transform_and_model_logscale(request):
    """Compile the log-simplex Stan model of one transform with a stand-in target.

    The model source is read from
    ``<stan_models_logscale_dir>/<transform_name>.stan``, the literal text
    ``target_density_lp(log_x, alpha)`` is replaced by
    ``sum(log_x[1:N - 1])``, and the result is written to a temporary file and
    compiled with CmdStan.

    Yields:
        A ``(transform_name, model)`` tuple, where ``model`` is the compiled
        ``cmdstanpy.CmdStanModel``.
    """
    transform_name = request.param
    model_file = os.path.join(stan_models_logscale_dir, f"{transform_name}.stan")
    # Read through a context manager so the file handle is closed promptly
    # instead of being left to the garbage collector.
    with open(model_file, "r") as f:
        model_code = f.read()
    model_code = model_code.replace(
        "target_density_lp(log_x, alpha)", "sum(log_x[1:N - 1])"
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_model_fn = os.path.join(tmpdir, "model.stan")
        with open(tmp_model_fn, "w") as f:
            f.write(model_code)
        model = cmdstanpy.CmdStanModel(stan_file=tmp_model_fn)
        # NOTE(review): the yield is kept inside the TemporaryDirectory block so
        # the compiled files outlive the fixture's users — confirm this nesting
        # against the original indentation.
        yield transform_name, model


@pytest.mark.parametrize("N", [3, 5, 10])
def test_stan_and_jax_transforms_consistent(transform_and_model, N):
transform_name, model = transform_and_model
# Root directories under which each target / transform has its own
# subdirectory of Stan include files (see make_stan_model).
targets_dir = os.path.join(project_dir, "targets")
transforms_dir = os.path.join(project_dir, "transforms")
# Module-level cache of compiled Stan models; the test keys it by
# (target_name, transform_name, log_scale) so each combination compiles once.
stan_models = {}


class MultiLogitNormal(NamedTuple):
    """Distribution on the simplex defined through a normal on ALR coordinates.

    ``log_prob`` maps a simplex point to unconstrained coordinates with
    ``jax_transforms.ALR``, evaluates a multivariate normal with location
    ``mu`` and Cholesky factor ``L_Sigma`` there, and subtracts the
    log-determinant returned by ``constrain_with_logdetjac``.
    """

    # Location of the normal in unconstrained space.
    mu: jax.Array
    # Lower-triangular scale factor passed as ``scale_tril``.
    L_Sigma: jax.Array

    @property
    def event_shape(self):
        """Shape of one draw: one more entry than ``mu`` has."""
        n_unconstrained = self.mu.shape[0]
        return (n_unconstrained + 1,)

    def log_prob(self, x):
        """Return the log density at the simplex point ``x``."""
        alr = jax_transforms.ALR()
        unconstrained = alr.unconstrain(x)
        log_det_jac = alr.constrain_with_logdetjac(unconstrained)[1]
        base_normal = tfd.MultivariateNormalTriL(
            loc=self.mu, scale_tril=self.L_Sigma
        )
        return base_normal.log_prob(unconstrained) - log_det_jac


def make_dirichlet_data(N: int, seed: int = 638):
    """Build Stan data for the Dirichlet target.

    Args:
        N: Number of entries in ``alpha``.
        seed: Seed for the NumPy random generator.

    Returns:
        A dict with ``"N"`` and ``"alpha"``; ``alpha`` is a list of ``N``
        values drawn uniformly, scaled by 10 and rounded to 4 decimals.
    """
    generator = np.random.default_rng(seed)
    concentrations = 10 * generator.uniform(size=N)
    return {"N": N, "alpha": np.around(concentrations, 4).tolist()}


def make_multi_logit_normal_data(N: int, seed: int = 638):
    """Build Stan data for the multi-logit-normal target.

    Args:
        N: Simplex dimension; ``mu`` has ``N - 1`` entries and ``L_Sigma`` is
            ``(N - 1, N - 1)``.
        seed: Seed for the NumPy random generator. All random draws come from
            this generator, so equal seeds give equal data.

    Returns:
        A dict with ``"N"``, ``"mu"`` and ``"L_Sigma"`` (nested lists rounded
        to 4 decimals). ``L_Sigma`` is lower triangular with a non-negative
        diagonal; before rounding, row ``i`` has Euclidean norm ``sigma[i]``.
    """
    rng = np.random.default_rng(seed)
    mu = 0.01 * rng.normal(size=N - 1)
    L_Sigma = np.tril(rng.normal(size=(N - 1, N - 1)))
    diaginds = np.diag_indices(N - 1)
    L_Sigma[diaginds] = np.abs(L_Sigma[diaginds])
    # Use the seeded generator (not the global np.random state) so the
    # generated scales are reproducible for a given seed.
    sigma = 100 * rng.uniform(size=N - 1)
    # Normalize each row to unit norm, then scale row i by sigma[i].
    L_Sigma = np.diag(sigma / np.linalg.norm(L_Sigma, axis=1)) @ L_Sigma
    return {
        "N": N,
        "mu": np.around(mu, 4).tolist(),
        "L_Sigma": np.around(L_Sigma, 4).tolist(),
    }


def make_model_data(target: str, *args, **kwargs):
    """Dispatch to the data builder for the named target.

    Args:
        target: Either ``"dirichlet"`` or ``"multi-logit-normal"``.
        *args: Positional arguments forwarded to the builder.
        **kwargs: Keyword arguments forwarded to the builder.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``target`` is not one of the known names.
    """
    if target == "dirichlet":
        return make_dirichlet_data(*args, **kwargs)
    if target == "multi-logit-normal":
        return make_multi_logit_normal_data(*args, **kwargs)
    raise ValueError(f"Unknown target {target}")


def make_jax_distribution(target: str, params: dict):
    """Construct the JAX-side distribution matching a Stan target.

    Args:
        target: Either ``"dirichlet"`` or ``"multi-logit-normal"``.
        params: Data dict; ``"alpha"`` is read for the Dirichlet target, and
            ``"mu"`` and ``"L_Sigma"`` for the multi-logit-normal target.

    Returns:
        A ``tfd.Dirichlet`` or a ``MultiLogitNormal`` instance.

    Raises:
        ValueError: If ``target`` is not one of the known names.
    """
    if target == "dirichlet":
        concentration = jnp.array(params["alpha"])
        return tfd.Dirichlet(concentration)
    if target == "multi-logit-normal":
        loc = jnp.array(params["mu"])
        scale_tril = jnp.array(params["L_Sigma"])
        return MultiLogitNormal(mu=loc, L_Sigma=scale_tril)
    raise ValueError(f"Unknown target {target}")


def make_stan_model(
    model_file: str, target_name: str, transform_name: str, log_scale: bool
) -> cmdstanpy.CmdStanModel:
    """Assemble a Stan program from include fragments and compile it.

    The generated program consists only of ``#include`` directives naming the
    target's and transform's fragment files; the target's and transform's
    subdirectories are passed to stanc as include paths.

    Args:
        model_file: Path the generated Stan program is written to.
        target_name: Name of the target; used as the subdirectory under
            ``targets_dir`` and as the file-name prefix of its fragments.
        transform_name: Name of the transform; used as the subdirectory under
            ``transforms_dir`` and as the file-name prefix of its fragments.
        log_scale: Selects the ``log_simplex`` fragments when true, otherwise
            the ``simplex`` fragments.

    Returns:
        The compiled ``cmdstanpy.CmdStanModel``.
    """
    if log_scale:
        space = "log_simplex"
    else:
        space = "simplex"
    model_code = f"""
functions {{
#include {target_name}_functions.stan
#include {transform_name}_functions.stan
}}
#include {target_name}_data.stan
#include {transform_name}_parameters_{space}.stan
#include {target_name}_model_{space}.stan
"""
    with open(model_file, "w") as f:
        f.write(model_code)
    include_dirs = (
        os.path.join(targets_dir, target_name),
        os.path.join(transforms_dir, transform_name),
    )
    return cmdstanpy.CmdStanModel(
        stan_file=model_file,
        stanc_options={"include-paths": ",".join(include_dirs)},
    )


@pytest.mark.parametrize("N", [3, 5])
@pytest.mark.parametrize("log_scale", [False, True])
@pytest.mark.parametrize("target_name", ["dirichlet", "multi-logit-normal"])
@pytest.mark.parametrize("transform_name", basic_transforms + expanded_transforms)
def test_stan_and_jax_transforms_consistent(
tmpdir, transform_name, target_name, N, log_scale, seed=638, stan_seed=348
):
if transform_name == "StanStickbreaking":
jax_transform_name = "StickbreakingLogistic"
else:
jax_transform_name = transform_name
try:
trans = getattr(jax_transforms, transform_name)()
trans = getattr(jax_transforms, jax_transform_name)()
except AttributeError:
pytest.skip(f"No JAX implementation of {transform_name}. Skipping.")
constrain_with_logdetjac_vec = jax.vmap(
jax.vmap(trans.constrain_with_logdetjac, 0), 0
)
data = {"N": N, "alpha": [1.0] * N}

result = model.sample(data=data, iter_sampling=100)
idata = az.convert_to_inference_data(result)
if target_name != "dirichlet" and transform_name not in ["ALR", "ILR"]:
pytest.skip(f"No need to test {transform_name} with {target_name}. Skipping.")

x_expected, lp_expected = constrain_with_logdetjac_vec(idata.posterior.y.data)
if transform_name in expanded_transforms:
r_expected, x_expected = x_expected
lp_expected += trans.default_prior(x_expected).log_prob(r_expected)

assert jnp.allclose(x_expected, idata.posterior.x.data, atol=1e-5)
assert jnp.allclose(lp_expected, idata.sample_stats.lp.data, atol=1e-5)


@pytest.mark.parametrize("N", [3, 5, 10])
def test_stan_and_jax_transforms_consistent_logscale(transform_and_model_logscale, N):
transform_name, model = transform_and_model_logscale
try:
trans = getattr(jax_transforms, transform_name)()
except AttributeError:
pytest.skip(f"No JAX implementation of {transform_name}. Skipping.")
constrain_with_logdetjac_vec = jax.vmap(
jax.vmap(trans.constrain_with_logdetjac, 0), 0
)
data = {"N": N, "alpha": [1.0] * N}

result = model.sample(data=data, iter_sampling=100)
data = make_model_data(target_name, N, seed=seed)
dist = make_jax_distribution(target_name, data)
log_prob = dist.log_prob

# get compiled model or compile and add to cache
model_key = (target_name, transform_name, log_scale)
if model_key not in stan_models:
model = make_stan_model(
os.path.join(
tmpdir,
f"{target_name}_{transform_name}_{'log_simplex' if log_scale else 'simplex'}.stan",
),
target_name,
transform_name,
log_scale,
)
stan_models[model_key] = model
else:
model = stan_models[(target_name, transform_name, log_scale)]

result = model.sample(data=data, iter_sampling=100, sig_figs=18, seed=stan_seed)
idata = az.convert_to_inference_data(result)

x_expected, lp_expected = constrain_with_logdetjac_vec(idata.posterior.y.data)
if transform_name == "StanStickbreaking":
y = trans.unconstrain(idata.posterior.x.data)
else:
y = idata.posterior.y.data
x_expected, lp_expected = constrain_with_logdetjac_vec(y)
if transform_name in expanded_transforms:
r_expected, x_expected = x_expected
lp_expected += trans.default_prior(x_expected).log_prob(r_expected)
log_x_expected = jnp.log(x_expected)
assert jnp.allclose(log_x_expected, idata.posterior.log_x.data, atol=1e-5)
assert jnp.allclose(x_expected, idata.posterior.x.data, atol=1e-5)
assert jnp.allclose(lp_expected, idata.sample_stats.lp.data, atol=1e-5)
lp_expected += log_prob(x_expected)
assert jnp.allclose(x_expected, idata.posterior.x.data, rtol=1e-4)
assert jnp.allclose(lp_expected, idata.sample_stats.lp.data, rtol=1e-4)
2 changes: 0 additions & 2 deletions stan_models/ALR_DirichletSymmetric.stan

This file was deleted.

2 changes: 0 additions & 2 deletions stan_models/Stickbreaking_DirichletSymmetric.stan

This file was deleted.

5 changes: 0 additions & 5 deletions target_densities/Dirichlet.stan

This file was deleted.

12 changes: 0 additions & 12 deletions target_densities/LogDirichlet.stan

This file was deleted.

4 changes: 4 additions & 0 deletions targets/dirichlet/dirichlet_data.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Data for the Dirichlet target.
data {
// length of alpha; must be at least 1
int<lower=1> N;
// non-negative vector passed as the second argument of the Dirichlet
// log density in the accompanying model blocks
vector<lower=0>[N] alpha;
}
13 changes: 13 additions & 0 deletions targets/dirichlet/dirichlet_functions.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/**
 * Return the log density of the Dirichlet distribution for a point given
 * on the log simplex.
 *
 * The value is dot_product(alpha, log_theta) - log_theta[N] plus the
 * normalizing term lgamma(sum(alpha)) - sum(lgamma(alpha)).
 * NOTE(review): the - log_theta[N] term is consistent with a change of
 * variables in which the first N - 1 log coordinates are the free
 * variables; confirm against the intended base measure.
 *
 * @param log_theta a vector on the log simplex (N rows)
 * @param alpha concentration parameters, i.e. prior counts plus one (N rows)
 * @return the log density evaluated at log_theta
 * @throws if log_theta and alpha do not have the same number of rows
 */
real log_dirichlet_lpdf(vector log_theta, vector alpha) {
int N = rows(log_theta);
if (N != rows(alpha))
reject("Input must contain same number of elements as alpha");
return dot_product(alpha, log_theta) - log_theta[N]
+ lgamma(sum(alpha)) - sum(lgamma(alpha));
}
3 changes: 3 additions & 0 deletions targets/dirichlet/dirichlet_model_log_simplex.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Dirichlet target evaluated on the log simplex: adds log_dirichlet_lpdf of
// log_x given alpha to the log density. log_x and alpha are declared in
// other included fragments.
model {
target += log_dirichlet_lpdf(log_x | alpha);
}
3 changes: 3 additions & 0 deletions targets/dirichlet/dirichlet_model_simplex.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Dirichlet target evaluated on the simplex: adds dirichlet_lpdf of x given
// alpha to the log density. x and alpha are declared in other included
// fragments.
model {
target += dirichlet_lpdf(x | alpha);
}
5 changes: 5 additions & 0 deletions targets/multi-logit-normal/multi-logit-normal_data.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Data for the multi-logit-normal target.
data {
// simplex dimension; mu and L_Sigma are sized N - 1
int<lower=1> N;
// location argument of the multi-logit-normal log density (N - 1 rows)
vector[N - 1] mu;
// Cholesky-factor argument of the multi-logit-normal log density
// (N - 1 rows, N - 1 cols); declared as a plain matrix, so no
// triangularity or positivity is checked here
matrix[N - 1, N - 1] L_Sigma;
}
31 changes: 31 additions & 0 deletions targets/multi-logit-normal/multi-logit-normal_functions.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
 * Log density of the multivariate logistic normal distribution at a point
 * on the simplex.
 *
 * The point is mapped to additive log-ratio coordinates relative to its
 * last component, a Cholesky-parameterized multivariate normal log density
 * is evaluated there, and sum(log(theta)) is subtracted.
 *
 * See: https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization
 *
 * @param theta a vector on the simplex (N rows)
 * @param mu location of normal (N-1 rows)
 * @param L_Sigma Cholesky factor of covariance (N-1 rows, N-1 cols)
 * @return the log density evaluated at theta
 */
real multi_logit_normal_cholesky_lpdf(vector theta, vector mu, matrix L_Sigma) {
  int N = rows(theta);
  vector[N] log_theta = log(theta);
  // log-ratios of the first N - 1 components to the last one
  vector[N - 1] log_ratios = log_theta[1 : N - 1] - log_theta[N];
  real lp_normal = multi_normal_cholesky_lpdf(log_ratios | mu, L_Sigma);
  return lp_normal - sum(log_theta);
}

/**
 * Log density of the multivariate logistic normal distribution at a point
 * given on the log simplex.
 *
 * The log-ratios of the first N - 1 components to the last one are formed
 * directly from log_theta, a Cholesky-parameterized multivariate normal
 * log density is evaluated there, and log_theta[N] is subtracted.
 *
 * See: https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization
 *
 * @param log_theta a vector on the log simplex (N rows)
 * @param mu location of normal (N-1 rows)
 * @param L_Sigma Cholesky factor of covariance (N-1 rows, N-1 cols)
 * @return the log density evaluated at log_theta
 */
real log_multi_logit_normal_cholesky_lpdf(vector log_theta, vector mu,
                                          matrix L_Sigma) {
  int N = rows(log_theta);
  // log-ratios of the first N - 1 components to the last one
  vector[N - 1] log_ratios = log_theta[1 : N - 1] - log_theta[N];
  real lp_normal = multi_normal_cholesky_lpdf(log_ratios | mu, L_Sigma);
  return lp_normal - log_theta[N];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Multi-logit-normal target evaluated on the log simplex: adds
// log_multi_logit_normal_cholesky_lpdf of log_x given mu and L_Sigma to the
// log density. log_x, mu and L_Sigma are declared in other included fragments.
model {
target += log_multi_logit_normal_cholesky_lpdf(log_x | mu, L_Sigma);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Multi-logit-normal target evaluated on the simplex: adds
// multi_logit_normal_cholesky_lpdf of x given mu and L_Sigma to the log
// density. x, mu and L_Sigma are declared in other included fragments.
model {
target += multi_logit_normal_cholesky_lpdf(x | mu, L_Sigma);
}
16 changes: 16 additions & 0 deletions transforms/ALR/ALR_functions.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/**
 * Map an unconstrained vector y (N - 1 rows) to a vector x (N rows) and
 * increment the log density.
 *
 * With r = log(1 + sum(exp(y))), x is exp(y - r) with exp(-r) appended, so
 * its entries are positive and sum to one. The log density is incremented
 * by sum(y) - N * r, which equals sum(log(x)); presumably this is the
 * log-Jacobian adjustment of the transform (confirm).
 *
 * @param y unconstrained vector (N - 1 rows)
 * @return x, a vector with N rows
 */
vector inv_alr_simplex_constrain_lp(vector y) {
int N = rows(y) + 1;
real r = log1p_exp(log_sum_exp(y));
vector[N] x = append_row(exp(y - r), exp(-r));
// adding a vector to target adds the sum of its elements
target += y;
target += -N * r;
return x;
}

/**
 * Map an unconstrained vector y (N - 1 rows) to a vector log_x (N rows) and
 * increment the log density.
 *
 * With r = log(1 + sum(exp(y))), log_x is y - r with -r appended, i.e. the
 * elementwise log of the vector built by the simplex variant of this
 * transform. The log density is incremented by -r, which equals log_x[N];
 * presumably this is the log-Jacobian adjustment of the transform (confirm).
 *
 * @param y unconstrained vector (N - 1 rows)
 * @return log_x, a vector with N rows
 */
vector inv_alr_log_simplex_constrain_lp(vector y) {
int N = rows(y) + 1;
real r = log1p_exp(log_sum_exp(y));
vector[N] log_x = append_row(y - r, -r);
target += -r;
return log_x;
}
7 changes: 7 additions & 0 deletions transforms/ALR/ALR_parameters_log_simplex.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// ALR parameterization, log-simplex variant. N is declared in another
// included fragment.
parameters {
// unconstrained vector fed to the inverse ALR transform
vector[N - 1] y;
}
transformed parameters {
// log-simplex point; the call also increments the log density
vector<upper=0>[N] log_x = inv_alr_log_simplex_constrain_lp(y);
// corresponding simplex point, validated by the simplex type
simplex[N] x = exp(log_x);
}
6 changes: 6 additions & 0 deletions transforms/ALR/ALR_parameters_simplex.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// ALR parameterization, simplex variant. N is declared in another included
// fragment.
parameters {
// unconstrained vector fed to the inverse ALR transform
vector[N - 1] y;
}
transformed parameters {
// simplex point; the call also increments the log density
simplex[N] x = inv_alr_simplex_constrain_lp(y);
}
18 changes: 18 additions & 0 deletions transforms/ExpandedSoftmax/ExpandedSoftmax_functions.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/**
 * Map a vector y (N rows) to x = exp(y - log_sum_exp(y)) (N rows) and
 * increment the log density.
 *
 * The entries of x are positive and sum to one. With r = log_sum_exp(y),
 * the log density is incremented by sum(y) - N * r, which equals
 * sum(log(x)), and by a standard-normal log density evaluated at
 * r - log(N); presumably the latter is a prior on the extra degree of
 * freedom of this over-parameterized transform (confirm).
 *
 * @param y unconstrained vector (N rows)
 * @return x, a vector with N rows
 */
vector expanded_softmax_simplex_constrain_lp(vector y) {
int N = rows(y);
real r = log_sum_exp(y);
vector[N] x = exp(y - r);
// adding a vector to target adds the sum of its elements
target += y;
target += -N * r;
target += std_normal_lpdf(r - log(N));
return x;
}

/**
 * Map a vector y (N rows) to log_x = y - log_sum_exp(y) (N rows) and
 * increment the log density.
 *
 * With r = log_sum_exp(y), the log density is incremented by log_x[N] and
 * by a standard-normal log density evaluated at r - log(N); presumably the
 * former is the log-Jacobian adjustment and the latter a prior on the extra
 * degree of freedom of this over-parameterized transform (confirm).
 *
 * @param y unconstrained vector (N rows)
 * @return log_x, a vector with N rows
 */
vector expanded_softmax_log_simplex_constrain_lp(vector y) {
int N = rows(y);
real r = log_sum_exp(y);
vector[N] log_x = y - r;
target += log_x[N];
target += std_normal_lpdf(r - log(N));
return log_x;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// ExpandedSoftmax parameterization, log-simplex variant. N is declared in
// another included fragment.
parameters {
// unconstrained vector with as many rows as the simplex
vector[N] y;
}
transformed parameters {
// log-simplex point; the call also increments the log density
vector<upper=0>[N] log_x = expanded_softmax_log_simplex_constrain_lp(y);
// corresponding simplex point, validated by the simplex type
simplex[N] x = exp(log_x);
}
Loading

0 comments on commit 36a599d

Please sign in to comment.