diff --git a/jax_transforms/tests/test_stan_transforms.py b/jax_transforms/tests/test_stan_transforms.py index f5c75f2..ccc6af0 100644 --- a/jax_transforms/tests/test_stan_transforms.py +++ b/jax_transforms/tests/test_stan_transforms.py @@ -1,11 +1,13 @@ import os -import tempfile +from typing import NamedTuple import arviz as az import cmdstanpy import jax import jax.numpy as jnp +import numpy as np import pytest +from tensorflow_probability.substrates.jax import distributions as tfd import jax_transforms @@ -14,6 +16,7 @@ basic_transforms = [ "ALR", "ILR", + "StanStickbreaking", "StickbreakingAngular", "StickbreakingLogistic", "StickbreakingNormal", @@ -27,84 +30,145 @@ ] project_dir = os.path.abspath(os.path.join(__file__, "..", "..", "..")) -stan_models_dir = os.path.join(project_dir, "transforms/simplex") -stan_models_logscale_dir = os.path.join(project_dir, "transforms/log_simplex") - - -@pytest.fixture(scope="module", params=basic_transforms + expanded_transforms) -def transform_and_model(request): - transform_name = request.param - model_file = os.path.join(stan_models_dir, f"{transform_name}.stan") - model_code = open(model_file, "r").read() - model_code = model_code.replace("target_density_lp(x, alpha)", "0") - with tempfile.TemporaryDirectory() as tmpdir: - tmp_model_fn = os.path.join(tmpdir, "model.stan") - with open(tmp_model_fn, "w") as f: - f.write(model_code) - model = cmdstanpy.CmdStanModel(stan_file=tmp_model_fn) - yield transform_name, model - - -@pytest.fixture(scope="module", params=basic_transforms + expanded_transforms) -def transform_and_model_logscale(request): - transform_name = request.param - model_file = os.path.join(stan_models_logscale_dir, f"{transform_name}.stan") - model_code = open(model_file, "r").read() - model_code = model_code.replace( - "target_density_lp(log_x, alpha)", "sum(log_x[1:N - 1])" - ) - with tempfile.TemporaryDirectory() as tmpdir: - tmp_model_fn = os.path.join(tmpdir, "model.stan") - with open(tmp_model_fn, "w") as f: - f.write(model_code) - model = cmdstanpy.CmdStanModel(stan_file=tmp_model_fn) - yield transform_name, model - - -@pytest.mark.parametrize("N", [3, 5, 10]) -def test_stan_and_jax_transforms_consistent(transform_and_model, N): - transform_name, model = transform_and_model +targets_dir = os.path.join(project_dir, "targets") +transforms_dir = os.path.join(project_dir, "transforms") +stan_models = {} + + +class MultiLogitNormal(NamedTuple): + mu: jax.Array + L_Sigma: jax.Array + + @property + def event_shape(self): + return (self.mu.shape[0] + 1,) + + def log_prob(self, x): + transform = jax_transforms.ALR() + y = transform.unconstrain(x) + logJ = transform.constrain_with_logdetjac(y)[1] + lp_mvnorm = tfd.MultivariateNormalTriL( + loc=self.mu, scale_tril=self.L_Sigma + ).log_prob(y) + return lp_mvnorm - logJ + + +def make_dirichlet_data(N: int, seed: int = 638): + rng = np.random.default_rng(seed) + alpha = rng.uniform(size=N) + return {"N": N, "alpha": np.around(10 * alpha, 4).tolist()} + + +def make_multi_logit_normal_data(N: int, seed: int = 638): + rng = np.random.default_rng(seed) + mu = 0.01 * rng.normal(size=N - 1) + L_Sigma = np.tril(rng.normal(size=(N - 1, N - 1))) + diaginds = np.diag_indices(N - 1) + L_Sigma[diaginds] = np.abs(L_Sigma[diaginds]) + sigma = 100 * np.random.uniform(size=N - 1) + L_Sigma = np.diag(sigma / np.linalg.norm(L_Sigma, axis=1)) @ L_Sigma + return { + "N": N, + "mu": np.around(mu, 4).tolist(), + "L_Sigma": np.around(L_Sigma, 4).tolist(), + } + + +def make_model_data(target: str, *args, **kwargs): + if target == "dirichlet": + return make_dirichlet_data(*args, **kwargs) + elif target == "multi-logit-normal": + return make_multi_logit_normal_data(*args, **kwargs) + else: + raise ValueError(f"Unknown target {target}") + + +def make_jax_distribution(target: str, params: dict): + if target == "dirichlet": + return tfd.Dirichlet(jnp.array(params["alpha"])) + elif target == "multi-logit-normal": + return MultiLogitNormal( + mu=jnp.array(params["mu"]), L_Sigma=jnp.array(params["L_Sigma"]) + ) + else: + raise ValueError(f"Unknown target {target}") + + +def make_stan_model( + model_file: str, target_name: str, transform_name: str, log_scale: bool +) -> cmdstanpy.CmdStanModel: + target_dir = os.path.join(targets_dir, target_name) + transform_dir = os.path.join(transforms_dir, transform_name) + space = "log_simplex" if log_scale else "simplex" + model_code = f""" + functions {{ + #include {target_name}_functions.stan + #include {transform_name}_functions.stan + }} + #include {target_name}_data.stan + #include {transform_name}_parameters_{space}.stan + #include {target_name}_model_{space}.stan + """ + with open(model_file, "w") as f: + f.write(model_code) + stanc_options = {"include-paths": ",".join([target_dir, transform_dir])} + model = cmdstanpy.CmdStanModel(stan_file=model_file, stanc_options=stanc_options) + return model + + +@pytest.mark.parametrize("N", [3, 5]) +@pytest.mark.parametrize("log_scale", [False, True]) +@pytest.mark.parametrize("target_name", ["dirichlet", "multi-logit-normal"]) +@pytest.mark.parametrize("transform_name", basic_transforms + expanded_transforms) +def test_stan_and_jax_transforms_consistent( + tmpdir, transform_name, target_name, N, log_scale, seed=638, stan_seed=348 +): + if transform_name == "StanStickbreaking": + jax_transform_name = "StickbreakingLogistic" + else: + jax_transform_name = transform_name try: - trans = getattr(jax_transforms, transform_name)() + trans = getattr(jax_transforms, jax_transform_name)() except AttributeError: pytest.skip(f"No JAX implementation of {transform_name}. Skipping.") - constrain_with_logdetjac_vec = jax.vmap( - jax.vmap(trans.constrain_with_logdetjac, 0), 0 - ) - data = {"N": N, "alpha": [1.0] * N} - - result = model.sample(data=data, iter_sampling=100) - idata = az.convert_to_inference_data(result) + if target_name != "dirichlet" and transform_name not in ["ALR", "ILR"]: + pytest.skip(f"No need to test {transform_name} with {target_name}. Skipping.") - x_expected, lp_expected = constrain_with_logdetjac_vec(idata.posterior.y.data) - if transform_name in expanded_transforms: - r_expected, x_expected = x_expected - lp_expected += trans.default_prior(x_expected).log_prob(r_expected) - - assert jnp.allclose(x_expected, idata.posterior.x.data, atol=1e-5) - assert jnp.allclose(lp_expected, idata.sample_stats.lp.data, atol=1e-5) - - -@pytest.mark.parametrize("N", [3, 5, 10]) -def test_stan_and_jax_transforms_consistent_logscale(transform_and_model_logscale, N): - transform_name, model = transform_and_model_logscale - try: - trans = getattr(jax_transforms, transform_name)() - except AttributeError: - pytest.skip(f"No JAX implementation of {transform_name}. Skipping.") constrain_with_logdetjac_vec = jax.vmap( jax.vmap(trans.constrain_with_logdetjac, 0), 0 ) - data = {"N": N, "alpha": [1.0] * N} - result = model.sample(data=data, iter_sampling=100) + data = make_model_data(target_name, N, seed=seed) + dist = make_jax_distribution(target_name, data) + log_prob = dist.log_prob + + # get compiled model or compile and add to cache + model_key = (target_name, transform_name, log_scale) + if model_key not in stan_models: + model = make_stan_model( + os.path.join( + tmpdir, + f"{target_name}_{transform_name}_{'log_simplex' if log_scale else 'simplex'}.stan", + ), + target_name, + transform_name, + log_scale, + ) + stan_models[model_key] = model + else: + model = stan_models[(target_name, transform_name, log_scale)] + + result = model.sample(data=data, iter_sampling=100, sig_figs=18, seed=stan_seed) idata = az.convert_to_inference_data(result) - x_expected, lp_expected = constrain_with_logdetjac_vec(idata.posterior.y.data) + if transform_name == "StanStickbreaking": + y = trans.unconstrain(idata.posterior.x.data) + else: + y = idata.posterior.y.data + x_expected, lp_expected = constrain_with_logdetjac_vec(y) if transform_name in expanded_transforms: r_expected, x_expected = x_expected lp_expected += trans.default_prior(x_expected).log_prob(r_expected) - log_x_expected = jnp.log(x_expected) - assert jnp.allclose(log_x_expected, idata.posterior.log_x.data, atol=1e-5) - assert jnp.allclose(x_expected, idata.posterior.x.data, atol=1e-5) - assert jnp.allclose(lp_expected, idata.sample_stats.lp.data, atol=1e-5) + lp_expected += log_prob(x_expected) + assert jnp.allclose(x_expected, idata.posterior.x.data, rtol=1e-4) + assert jnp.allclose(lp_expected, idata.sample_stats.lp.data, rtol=1e-4) diff --git a/stan_models/ALR_DirichletSymmetric.stan b/stan_models/ALR_DirichletSymmetric.stan deleted file mode 100644 index 955ba14..0000000 --- a/stan_models/ALR_DirichletSymmetric.stan +++ /dev/null @@ -1,2 +0,0 @@ -#include ../target_densities/DirichletSymmetric.stan -#include ../transforms/simplex/ALR.stan diff --git a/stan_models/Stickbreaking_DirichletSymmetric.stan b/stan_models/Stickbreaking_DirichletSymmetric.stan deleted file mode 100644 index fbcfd5f..0000000 --- a/stan_models/Stickbreaking_DirichletSymmetric.stan +++ /dev/null @@ -1,2 +0,0 @@ -#include ../target_densities/DirichletSymmetric.stan -#include ../transforms/simplex/Stickbreaking.stan diff --git a/target_densities/Dirichlet.stan b/target_densities/Dirichlet.stan deleted file mode 100644 index 6bc560c..0000000 --- a/target_densities/Dirichlet.stan +++ /dev/null @@ -1,5 +0,0 @@ -functions { - real target_density_lp(vector x, vector alpha){ - return dirichlet_lpdf(x | alpha); - } -} diff --git a/target_densities/LogDirichlet.stan b/target_densities/LogDirichlet.stan deleted file mode 100644 index 47a0188..0000000 --- a/target_densities/LogDirichlet.stan +++ /dev/null @@ -1,12 +0,0 @@ -functions { - real log_dirichlet_lpdf(vector log_theta, vector alpha) { - int N = rows(log_theta); - if (N != rows(alpha)) reject("Input must contain same number of elements as alpha"); - real lp = dot_product(alpha, log_theta) - log_theta[N]; - lp += sum(lgamma(alpha)) - lgamma(sum(alpha)); - return lp; - } - real target_density_lp(vector log_x, vector alpha){ - return log_dirichlet_lpdf(log_x | alpha); - } -} diff --git a/targets/dirichlet/dirichlet_data.stan b/targets/dirichlet/dirichlet_data.stan new file mode 100644 index 0000000..6ac619b --- /dev/null +++ b/targets/dirichlet/dirichlet_data.stan @@ -0,0 +1,4 @@ +data { + int N; + vector[N] alpha; +} diff --git a/targets/dirichlet/dirichlet_functions.stan b/targets/dirichlet/dirichlet_functions.stan new file mode 100644 index 0000000..9cd7fc7 --- /dev/null +++ b/targets/dirichlet/dirichlet_functions.stan @@ -0,0 +1,13 @@ +/** + * Return the Dirichlet density for the specified log simplex. + * + * @param theta a vector on the log simplex (N rows) + * @param alpha prior counts plus one (N rows) + */ +real log_dirichlet_lpdf(vector log_theta, vector alpha) { + int N = rows(log_theta); + if (N != rows(alpha)) + reject("Input must contain same number of elements as alpha"); + return dot_product(alpha, log_theta) - log_theta[N] + + lgamma(sum(alpha)) - sum(lgamma(alpha)); +} diff --git a/targets/dirichlet/dirichlet_model_log_simplex.stan b/targets/dirichlet/dirichlet_model_log_simplex.stan new file mode 100644 index 0000000..70c7e31 --- /dev/null +++ b/targets/dirichlet/dirichlet_model_log_simplex.stan @@ -0,0 +1,3 @@ +model { + target += log_dirichlet_lpdf(log_x | alpha); +} diff --git a/targets/dirichlet/dirichlet_model_simplex.stan b/targets/dirichlet/dirichlet_model_simplex.stan new file mode 100644 index 0000000..5e4b433 --- /dev/null +++ b/targets/dirichlet/dirichlet_model_simplex.stan @@ -0,0 +1,3 @@ +model { + target += dirichlet_lpdf(x | alpha); +} diff --git a/targets/multi-logit-normal/multi-logit-normal_data.stan b/targets/multi-logit-normal/multi-logit-normal_data.stan new file mode 100644 index 0000000..b0624e3 --- /dev/null +++ b/targets/multi-logit-normal/multi-logit-normal_data.stan @@ -0,0 +1,5 @@ +data { + int N; + vector[N - 1] mu; + matrix[N - 1, N - 1] L_Sigma; +} diff --git a/targets/multi-logit-normal/multi-logit-normal_functions.stan b/targets/multi-logit-normal/multi-logit-normal_functions.stan new file mode 100644 index 0000000..a1c6f3a --- /dev/null +++ b/targets/multi-logit-normal/multi-logit-normal_functions.stan @@ -0,0 +1,31 @@ +/** + * Return the multivariate logistic normal density for the specified simplex. + * + * See: https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization + * + * @param theta a vector on the simplex (N rows) + * @param mu location of normal (N-1 rows) + * @param L_Sigma Cholesky factor of covariance (N-1 rows, N-1 cols) + */ +real multi_logit_normal_cholesky_lpdf(vector theta, vector mu, matrix L_Sigma) { + int N = rows(theta); + vector[N] log_theta = log(theta); + return multi_normal_cholesky_lpdf(log_theta[1 : N - 1] - log_theta[N] | mu, L_Sigma) + - sum(log_theta); +} + +/** + * Return the multivariate logistic normal density for the specified log simplex. + * + * See: https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization + * + * @param theta a vector on the log simplex (N rows) + * @param mu location of normal (N-1 rows) + * @param L_Sigma Cholesky factor of covariance (N-1 rows, N-1 cols) + */ +real log_multi_logit_normal_cholesky_lpdf(vector log_theta, vector mu, + matrix L_Sigma) { + int N = rows(log_theta); + return multi_normal_cholesky_lpdf(log_theta[1 : N - 1] - log_theta[N] | mu, L_Sigma) + - log_theta[N]; +} diff --git a/targets/multi-logit-normal/multi-logit-normal_model_log_simplex.stan b/targets/multi-logit-normal/multi-logit-normal_model_log_simplex.stan new file mode 100644 index 0000000..1fd29ff --- /dev/null +++ b/targets/multi-logit-normal/multi-logit-normal_model_log_simplex.stan @@ -0,0 +1,3 @@ +model { + target += log_multi_logit_normal_cholesky_lpdf(log_x | mu, L_Sigma); +} diff --git a/targets/multi-logit-normal/multi-logit-normal_model_simplex.stan b/targets/multi-logit-normal/multi-logit-normal_model_simplex.stan new file mode 100644 index 0000000..35f2186 --- /dev/null +++ b/targets/multi-logit-normal/multi-logit-normal_model_simplex.stan @@ -0,0 +1,3 @@ +model { + target += multi_logit_normal_cholesky_lpdf(x | mu, L_Sigma); +} diff --git a/transforms/ALR/ALR_functions.stan b/transforms/ALR/ALR_functions.stan new file mode 100644 index 0000000..dd18353 --- /dev/null +++ b/transforms/ALR/ALR_functions.stan @@ -0,0 +1,16 @@ +vector inv_alr_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + real r = log1p_exp(log_sum_exp(y)); + vector[N] x = append_row(exp(y - r), exp(-r)); + target += y; + target += -N * r; + return x; +} + +vector inv_alr_log_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + real r = log1p_exp(log_sum_exp(y)); + vector[N] log_x = append_row(y - r, -r); + target += -r; + return log_x; +} diff --git a/transforms/ALR/ALR_parameters_log_simplex.stan b/transforms/ALR/ALR_parameters_log_simplex.stan new file mode 100644 index 0000000..1fc2e4c --- /dev/null +++ b/transforms/ALR/ALR_parameters_log_simplex.stan @@ -0,0 +1,7 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + vector[N] log_x = inv_alr_log_simplex_constrain_lp(y); + simplex[N] x = exp(log_x); +} diff --git a/transforms/ALR/ALR_parameters_simplex.stan b/transforms/ALR/ALR_parameters_simplex.stan new file mode 100644 index 0000000..c0ef345 --- /dev/null +++ b/transforms/ALR/ALR_parameters_simplex.stan @@ -0,0 +1,6 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + simplex[N] x = inv_alr_simplex_constrain_lp(y); +} diff --git a/transforms/ExpandedSoftmax/ExpandedSoftmax_functions.stan b/transforms/ExpandedSoftmax/ExpandedSoftmax_functions.stan new file mode 100644 index 0000000..c058866 --- /dev/null +++ b/transforms/ExpandedSoftmax/ExpandedSoftmax_functions.stan @@ -0,0 +1,18 @@ +vector expanded_softmax_simplex_constrain_lp(vector y) { + int N = rows(y); + real r = log_sum_exp(y); + vector[N] x = exp(y - r); + target += y; + target += -N * r; + target += std_normal_lpdf(r - log(N)); + return x; +} + +vector expanded_softmax_log_simplex_constrain_lp(vector y) { + int N = rows(y); + real r = log_sum_exp(y); + vector[N] log_x = y - r; + target += log_x[N]; + target += std_normal_lpdf(r - log(N)); + return log_x; +} diff --git a/transforms/ExpandedSoftmax/ExpandedSoftmax_parameters_log_simplex.stan b/transforms/ExpandedSoftmax/ExpandedSoftmax_parameters_log_simplex.stan new file mode 100644 index 0000000..88a80ee --- /dev/null +++ b/transforms/ExpandedSoftmax/ExpandedSoftmax_parameters_log_simplex.stan @@ -0,0 +1,7 @@ +parameters { + vector[N] y; +} +transformed parameters { + vector[N] log_x = expanded_softmax_log_simplex_constrain_lp(y); + simplex[N] x = exp(log_x); +} diff --git a/transforms/ExpandedSoftmax/ExpandedSoftmax_parameters_simplex.stan b/transforms/ExpandedSoftmax/ExpandedSoftmax_parameters_simplex.stan new file mode 100644 index 0000000..9305728 --- /dev/null +++ b/transforms/ExpandedSoftmax/ExpandedSoftmax_parameters_simplex.stan @@ -0,0 +1,6 @@ +parameters { + vector[N] y; +} +transformed parameters { + simplex[N] x = expanded_softmax_simplex_constrain_lp(y); +} diff --git a/transforms/ILR/ILR_functions.stan b/transforms/ILR/ILR_functions.stan new file mode 100644 index 0000000..7ab457e --- /dev/null +++ b/transforms/ILR/ILR_functions.stan @@ -0,0 +1,30 @@ +matrix semiorthogonal_matrix(data int N) { + matrix[N, N - 1] V; + real inv_nrm2; + for (n in 1 : (N - 1)) { + inv_nrm2 = inv_sqrt(n * (n + 1)); + V[1 : n, n] = rep_vector(inv_nrm2, n); + V[n + 1, n] = -n * inv_nrm2; + V[(n + 2) : N, n] = rep_vector(0, N - n - 1); + } + return V; +} + +vector inv_ilr_simplex_constrain_lp(vector y, data matrix V) { + int N = rows(y) + 1; + vector[N] z = V * y; + real r = log_sum_exp(z); + vector[N] x = exp(z - r); + target += z; + target += 0.5 * log(N) - N * r; + return x; +} + +vector inv_ilr_log_simplex_constrain_lp(vector y, data matrix V) { + int N = rows(y) + 1; + vector[N] z = V * y; + real r = log_sum_exp(z); + vector[N] log_x = z - r; + target += 0.5 * log(N) + log_x[N]; + return log_x; +} diff --git a/transforms/ILR/ILR_parameters_log_simplex.stan b/transforms/ILR/ILR_parameters_log_simplex.stan new file mode 100644 index 0000000..8748da3 --- /dev/null +++ b/transforms/ILR/ILR_parameters_log_simplex.stan @@ -0,0 +1,10 @@ +transformed data { + matrix[N, N - 1] V = semiorthogonal_matrix(N); +} +parameters { + vector[N - 1] y; +} +transformed parameters { + vector[N] log_x = inv_ilr_log_simplex_constrain_lp(y, V);; + simplex[N] x = exp(log_x); +} diff --git a/transforms/ILR/ILR_parameters_simplex.stan b/transforms/ILR/ILR_parameters_simplex.stan new file mode 100644 index 0000000..5d9468a --- /dev/null +++ b/transforms/ILR/ILR_parameters_simplex.stan @@ -0,0 +1,9 @@ +transformed data { + matrix[N, N - 1] V = semiorthogonal_matrix(N); +} +parameters { + vector[N - 1] y; +} +transformed parameters { + simplex[N] x = inv_ilr_simplex_constrain_lp(y, V); +} diff --git a/transforms/NormalizedExponential/NormalizedExponential_functions.stan b/transforms/NormalizedExponential/NormalizedExponential_functions.stan new file mode 100644 index 0000000..841dbaf --- /dev/null +++ b/transforms/NormalizedExponential/NormalizedExponential_functions.stan @@ -0,0 +1,30 @@ +real exponential_log_qf(real logp) { + return -log1m_exp(logp); +} + +vector normalized_exponential_simplex_constrain_lp(vector y) { + int N = rows(y); + vector[N] z; + for (i in 1 : N) { + real log_u = std_normal_lcdf(y[i]); + z[i] = log(exponential_log_qf(log_u)); + } + real r = log_sum_exp(z); + vector[N] x = exp(z - r); + target += std_normal_lpdf(y) - lgamma(N); + return x; +} + +vector normalized_exponential_log_simplex_constrain_lp(vector y) { + int N = rows(y); + vector[N] z; + for (i in 1 : N) { + real log_u = std_normal_lcdf(y[i]); + z[i] = log(exponential_log_qf(log_u)); + } + real r = log_sum_exp(z); + vector[N] log_x = z - r; + target += -log_x[1 : N - 1]; + target += std_normal_lpdf(y) - lgamma(N); + return log_x; +} diff --git a/transforms/NormalizedExponential/NormalizedExponential_parameters_log_simplex.stan b/transforms/NormalizedExponential/NormalizedExponential_parameters_log_simplex.stan new file mode 100644 index 0000000..5e959d8 --- /dev/null +++ b/transforms/NormalizedExponential/NormalizedExponential_parameters_log_simplex.stan @@ -0,0 +1,7 @@ +parameters { + vector[N] y; +} +transformed parameters { + vector[N] log_x = normalized_exponential_log_simplex_constrain_lp(y); + simplex[N] x = exp(log_x); +} diff --git a/transforms/NormalizedExponential/NormalizedExponential_parameters_simplex.stan b/transforms/NormalizedExponential/NormalizedExponential_parameters_simplex.stan new file mode 100644 index 0000000..76b95ea --- /dev/null +++ b/transforms/NormalizedExponential/NormalizedExponential_parameters_simplex.stan @@ -0,0 +1,6 @@ +parameters { + vector[N] y; +} +transformed parameters { + simplex[N] x = normalized_exponential_simplex_constrain_lp(y); +} diff --git a/transforms/StanStickbreaking/StanStickbreaking_functions.stan b/transforms/StanStickbreaking/StanStickbreaking_functions.stan new file mode 100644 index 0000000..02bd724 --- /dev/null +++ b/transforms/StanStickbreaking/StanStickbreaking_functions.stan @@ -0,0 +1,6 @@ +vector simplex_to_log_simplex_transform_lp(vector x) { + int N = rows(x); + vector[N] log_x = log(x); + target += -log_x[1 : N - 1]; + return log_x; +} diff --git a/transforms/StanStickbreaking/StanStickbreaking_parameters_log_simplex.stan b/transforms/StanStickbreaking/StanStickbreaking_parameters_log_simplex.stan new file mode 100644 index 0000000..02d2935 --- /dev/null +++ b/transforms/StanStickbreaking/StanStickbreaking_parameters_log_simplex.stan @@ -0,0 +1,6 @@ +parameters { + simplex[N] x; +} +transformed parameters { + vector[N] log_x = simplex_to_log_simplex_transform_lp(x); +} diff --git a/transforms/StanStickbreaking/StanStickbreaking_parameters_simplex.stan b/transforms/StanStickbreaking/StanStickbreaking_parameters_simplex.stan new file mode 100644 index 0000000..622fd6b --- /dev/null +++ b/transforms/StanStickbreaking/StanStickbreaking_parameters_simplex.stan @@ -0,0 +1,3 @@ +parameters { + simplex[N] x; +} diff --git a/transforms/StickbreakingAngular/StickbreakingAngular_functions.stan b/transforms/StickbreakingAngular/StickbreakingAngular_functions.stan new file mode 100644 index 0000000..0de2cc8 --- /dev/null +++ b/transforms/StickbreakingAngular/StickbreakingAngular_functions.stan @@ -0,0 +1,44 @@ +vector stickbricking_angular_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] x; + real s2_prod = 1; + real log_halfpi = log(pi()) - log2(); + int rcounter = 2 * N - 3; + for (i in 1 : (N - 1)) { + real u = log_inv_logit(y[i]); + real log_phi = u + log_halfpi; + real phi = exp(log_phi); + real s = sin(phi); + real c = cos(phi); + x[i] = s2_prod * c ^ 2; + s2_prod *= s ^ 2; + target += log_phi + log1m_exp(u) + rcounter * log(s) + log(c); + rcounter -= 2; + } + x[N] = s2_prod; + target += (N - 1) * log2(); + return x; +} + +vector stickbricking_angular_log_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] log_x; + real log_s2_prod = 0; + real log_halfpi = log(pi()) - log2(); + int rcounter = 2 * N - 3; + for (i in 1 : (N - 1)) { + real log_u = log_inv_logit(y[i]); + real log_phi = log_u + log_halfpi; + real phi = exp(log_phi); + real log_s = log(sin(phi)); + real log_c = log(cos(phi)); + log_x[i] = log_s2_prod + 2 * log_c; + log_s2_prod += 2 * log_s; + target += log_phi + log1m_exp(log_u) + rcounter * log_s + log_c; + target += -log_x[i]; + rcounter -= 2; + } + log_x[N] = log_s2_prod; + target += (N - 1) * log2(); + return log_x; +} diff --git a/transforms/StickbreakingAngular/StickbreakingAngular_parameters_log_simplex.stan b/transforms/StickbreakingAngular/StickbreakingAngular_parameters_log_simplex.stan new file mode 100644 index 0000000..d3c30fa --- /dev/null +++ b/transforms/StickbreakingAngular/StickbreakingAngular_parameters_log_simplex.stan @@ -0,0 +1,7 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + vector[N] log_x = stickbricking_angular_log_simplex_constrain_lp(y); + simplex[N] x = exp(log_x); +} diff --git a/transforms/StickbreakingAngular/StickbreakingAngular_parameters_simplex.stan b/transforms/StickbreakingAngular/StickbreakingAngular_parameters_simplex.stan new file mode 100644 index 0000000..777409c --- /dev/null +++ b/transforms/StickbreakingAngular/StickbreakingAngular_parameters_simplex.stan @@ -0,0 +1,6 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + simplex[N] x = stickbricking_angular_simplex_constrain_lp(y); +} diff --git a/transforms/StickbreakingLogistic/StickbreakingLogistic_functions.stan b/transforms/StickbreakingLogistic/StickbreakingLogistic_functions.stan new file mode 100644 index 0000000..8b14d29 --- /dev/null +++ b/transforms/StickbreakingLogistic/StickbreakingLogistic_functions.stan @@ -0,0 +1,29 @@ +vector stickbreaking_logistic_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] x; + real log_cum_prod = 0; + for (i in 1 : N - 1) { + real log_zi = log_inv_logit(y[i] - log(N - i)); // logistic_lcdf(y[i] | log(N - i), 1) + real log_xi = log_cum_prod + log_zi; + x[i] = exp(log_xi); + log_cum_prod += log1m_exp(log_zi); + target += log_xi; + } + x[N] = exp(log_cum_prod); + target += log_cum_prod; + return x; +} + +vector stickbreaking_logistic_log_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] log_x; + real log_cum_prod = 0; + for (i in 1 : (N - 1)) { + real log_z = log_inv_logit(y[i] - log(N - i)); // logistic_lcdf(y[i] | log(N - i), 1) + log_x[i] = log_cum_prod + log_z; + log_cum_prod += log1m_exp(log_z); + } + log_x[N] = log_cum_prod; + target += log_cum_prod; + return log_x; +} diff --git a/transforms/StickbreakingLogistic/StickbreakingLogistic_parameters_log_simplex.stan b/transforms/StickbreakingLogistic/StickbreakingLogistic_parameters_log_simplex.stan new file mode 100644 index 0000000..a43ed3a --- /dev/null +++ b/transforms/StickbreakingLogistic/StickbreakingLogistic_parameters_log_simplex.stan @@ -0,0 +1,7 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + vector[N] log_x = stickbreaking_logistic_log_simplex_constrain_lp(y); + simplex[N] x = exp(log_x); +} diff --git a/transforms/StickbreakingLogistic/StickbreakingLogistic_parameters_simplex.stan b/transforms/StickbreakingLogistic/StickbreakingLogistic_parameters_simplex.stan new file mode 100644 index 0000000..2d43087 --- /dev/null +++ b/transforms/StickbreakingLogistic/StickbreakingLogistic_parameters_simplex.stan @@ -0,0 +1,6 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + simplex[N] x = stickbreaking_logistic_simplex_constrain_lp(y); +} diff --git a/transforms/StickbreakingNormal/StickbreakingNormal_functions.stan b/transforms/StickbreakingNormal/StickbreakingNormal_functions.stan new file mode 100644 index 0000000..d963d61 --- /dev/null +++ b/transforms/StickbreakingNormal/StickbreakingNormal_functions.stan @@ -0,0 +1,30 @@ +vector stickbreaking_normal_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] x; + real log_cum_prod = 0; + for (i in 1 : N - 1) { + real wi = y[i] - log(N - i) / 2; + real log_zi = std_normal_lcdf(wi); + real log_xi = log_cum_prod + log_zi; + x[i] = exp(log_xi); + target += std_normal_lpdf(wi) + log_cum_prod; + log_cum_prod += log1m_exp(log_zi); + } + x[N] = exp(log_cum_prod); + return x; +} + +vector stickbreaking_normal_log_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] log_x; + real log_cum_prod = 0; + for (i in 1 : N - 1) { + real wi = y[i] - log(N - i) / 2; + real log_zi = std_normal_lcdf(wi); + log_x[i] = log_cum_prod + log_zi; + target += std_normal_lpdf(wi) - log_zi; + log_cum_prod += log1m_exp(log_zi); + } + log_x[N] = log_cum_prod; + return log_x; +} diff --git a/transforms/StickbreakingNormal/StickbreakingNormal_parameters_log_simplex.stan b/transforms/StickbreakingNormal/StickbreakingNormal_parameters_log_simplex.stan new file mode 100644 index 0000000..4a64a2a --- /dev/null +++ b/transforms/StickbreakingNormal/StickbreakingNormal_parameters_log_simplex.stan @@ -0,0 +1,7 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + vector[N] log_x = stickbreaking_normal_log_simplex_constrain_lp(y); + simplex[N] x = exp(log_x); +} diff --git a/transforms/StickbreakingNormal/StickbreakingNormal_parameters_simplex.stan b/transforms/StickbreakingNormal/StickbreakingNormal_parameters_simplex.stan new file mode 100644 index 0000000..c28ec17 --- /dev/null +++ b/transforms/StickbreakingNormal/StickbreakingNormal_parameters_simplex.stan @@ -0,0 +1,6 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + simplex[N] x = stickbreaking_normal_simplex_constrain_lp(y); +} diff --git a/transforms/StickbreakingPowerLogistic/StickbreakingPowerLogistic_functions.stan b/transforms/StickbreakingPowerLogistic/StickbreakingPowerLogistic_functions.stan new file mode 100644 index 0000000..84bb384 --- /dev/null +++ b/transforms/StickbreakingPowerLogistic/StickbreakingPowerLogistic_functions.stan @@ -0,0 +1,34 @@ +vector stickbreaking_power_logistic_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] x; + real log_cum_prod = 0; + for (i in 1 : (N - 1)) { + real log_u = log_inv_logit(y[i]); // logistic_lcdf(y[i] | 0, 1); + real log_w = log_u / (N - i); + real log_z = log1m_exp(log_w); + x[i] = exp(log_cum_prod + log_z); + target += 2 * log_u - y[i]; // logistic_lpdf(y[i] | 0, 1); + log_cum_prod += log1m_exp(log_z); + } + x[N] = exp(log_cum_prod); + target += -lgamma(N); + return x; +} + +vector stickbreaking_power_logistic_log_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] log_x; + real log_cum_prod = 0; + for (i in 1 : (N - 1)) { + real log_u = log_inv_logit(y[i]); // logistic_lcdf(y[i] | 0, 1); + real log_w = log_u / (N - i); + real log_z = log1m_exp(log_w); + log_x[i] = log_cum_prod + log_z; + target += 2 * log_u - y[i]; // logistic_lpdf(y[i] | 0, 1); + target += -log_x[i]; + log_cum_prod += log1m_exp(log_z); + } + log_x[N] = log_cum_prod; + target += -lgamma(N); + return log_x; +} diff --git a/transforms/StickbreakingPowerLogistic/StickbreakingPowerLogistic_parameters_log_simplex.stan b/transforms/StickbreakingPowerLogistic/StickbreakingPowerLogistic_parameters_log_simplex.stan new file mode 100644 index 0000000..96ee03c --- /dev/null +++ b/transforms/StickbreakingPowerLogistic/StickbreakingPowerLogistic_parameters_log_simplex.stan @@ -0,0 +1,7 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + vector[N] log_x = stickbreaking_power_logistic_log_simplex_constrain_lp(y); + simplex[N] x = exp(log_x); +} diff --git a/transforms/StickbreakingPowerLogistic/StickbreakingPowerLogistic_parameters_simplex.stan b/transforms/StickbreakingPowerLogistic/StickbreakingPowerLogistic_parameters_simplex.stan new file mode 100644 index 0000000..c5abfa6 --- /dev/null +++ b/transforms/StickbreakingPowerLogistic/StickbreakingPowerLogistic_parameters_simplex.stan @@ -0,0 +1,6 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + simplex[N] x = stickbreaking_power_logistic_simplex_constrain_lp(y); +} diff --git a/transforms/StickbreakingPowerNormal/StickbreakingPowerNormal_functions.stan b/transforms/StickbreakingPowerNormal/StickbreakingPowerNormal_functions.stan new file mode 100644 index 0000000..9021c99 --- /dev/null +++ b/transforms/StickbreakingPowerNormal/StickbreakingPowerNormal_functions.stan @@ -0,0 +1,34 @@ +vector stickbreaking_power_normal_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] x; + real log_cum_prod = 0; + for (i in 1 : (N - 1)) { + real log_u = std_normal_lcdf(y[i]); + real log_w = log_u / (N - i); + real log_z = log1m_exp(log_w); + x[i] = exp(log_cum_prod + log_z); + log_cum_prod += log1m_exp(log_z); + } + x[N] = exp(log_cum_prod); + target += std_normal_lpdf(y); + target += -lgamma(N); + return x; +} + +vector stickbreaking_power_normal_log_simplex_constrain_lp(vector y) { + int N = rows(y) + 1; + vector[N] log_x; + real log_cum_prod = 0; + for (i in 1 : (N - 1)) { + real log_u = std_normal_lcdf(y[i]); + real log_w = log_u / (N - i); + real log_z = log1m_exp(log_w); + log_x[i] = log_cum_prod + log_z; + target += -log_x[i]; + log_cum_prod += log1m_exp(log_z); + } + log_x[N] = log_cum_prod; + target += std_normal_lpdf(y); + target += -lgamma(N); + return log_x; +} diff --git a/transforms/StickbreakingPowerNormal/StickbreakingPowerNormal_parameters_log_simplex.stan b/transforms/StickbreakingPowerNormal/StickbreakingPowerNormal_parameters_log_simplex.stan new file mode 100644 index 0000000..b6befe4 --- /dev/null +++ b/transforms/StickbreakingPowerNormal/StickbreakingPowerNormal_parameters_log_simplex.stan @@ -0,0 +1,7 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + vector[N] log_x = stickbreaking_power_normal_log_simplex_constrain_lp(y); + simplex[N] x = exp(log_x); +} diff --git a/transforms/StickbreakingPowerNormal/StickbreakingPowerNormal_parameters_simplex.stan b/transforms/StickbreakingPowerNormal/StickbreakingPowerNormal_parameters_simplex.stan new file mode 100644 index 0000000..ed68b14 --- /dev/null +++ b/transforms/StickbreakingPowerNormal/StickbreakingPowerNormal_parameters_simplex.stan @@ -0,0 +1,6 @@ +parameters { + vector[N - 1] y; +} +transformed parameters { + simplex[N] x = stickbreaking_power_normal_simplex_constrain_lp(y); +} diff --git a/transforms/log_simplex/ALR.stan b/transforms/log_simplex/ALR.stan deleted file mode 100644 index 852ad52..0000000 --- a/transforms/log_simplex/ALR.stan +++ /dev/null @@ -1,25 +0,0 @@ -functions { - vector inv_alr_log_simplex_constrain_lp(vector y){ - int N = rows(y) + 1; - real r = log1p_exp(log_sum_exp(y)); - vector[N] log_x; - log_x[1:N - 1] = y - r; - log_x[N] = -r; - target += -r; - return log_x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - vector[N] log_x = inv_alr_log_simplex_constrain_lp(y); - simplex[N] x = exp(log_x); -} -model { - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/log_simplex/ExpandedSoftmax.stan b/transforms/log_simplex/ExpandedSoftmax.stan deleted file mode 100644 index 56497a6..0000000 --- a/transforms/log_simplex/ExpandedSoftmax.stan +++ /dev/null @@ -1,24 +0,0 @@ -functions{ - vector expanded_softmax_log_simplex_constrain_lp(vector y) { - int N = rows(y); - real r = log_sum_exp(y); - vector[N] log_x = y - r; - target += log_x[N]; - target += std_normal_lpdf(r - log(N)); - return log_x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N] y; -} -transformed parameters { - vector[N] log_x = expanded_softmax_log_simplex_constrain_lp(y); - simplex[N] x = exp(log_x); -} -model { - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/log_simplex/ILR.stan b/transforms/log_simplex/ILR.stan deleted file mode 100644 index 8ee923f..0000000 --- a/transforms/log_simplex/ILR.stan +++ /dev/null @@ -1,39 +0,0 @@ -functions { - matrix semiorthogonal_matrix(int N) { - matrix[N, N - 1] V; - real inv_nrm2; - for (n in 1:(N - 1)) { - inv_nrm2 = inv_sqrt(n * (n + 1)); - V[1:n, n] = rep_vector(inv_nrm2, n); - V[n + 1, n] = -n * inv_nrm2; - V[(n + 2):N, n] = rep_vector(0, N - n - 1); - } - return V; - } - - vector inv_ilr_log_simplex_constrain_lp(vector y, matrix V) { - int N = rows(y) + 1; - vector[N] z = V * y; - real r = log_sum_exp(z); - vector[N] log_x = z - r; - target += log_x[N] + 0.5 * log(N); - return log_x; - } -} -data { - int N; - vector[N] alpha; -} -transformed data { - matrix[N, N - 1] V = semiorthogonal_matrix(N); -} -parameters { - vector[N - 1] y; -} -transformed parameters { - vector[N] log_x = inv_ilr_log_simplex_constrain_lp(y, V);; - simplex[N] x = exp(log_x); -} -model { - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/log_simplex/NormalizedExponential.stan b/transforms/log_simplex/NormalizedExponential.stan deleted file mode 100644 index 55d1371..0000000 --- a/transforms/log_simplex/NormalizedExponential.stan +++ /dev/null @@ -1,33 +0,0 @@ -functions { - real exponential_log_qf(real logp){ - return -log1m_exp(logp); - } - vector normalized_exponential_log_simplex_constrain_lp(vector y) { - int N = rows(y); - vector[N] z; - real log_u; - for (i in 1:N) { - log_u = std_normal_lcdf(y[i]); - z[i] = log(exponential_log_qf(log_u)); - } - real r = log_sum_exp(z); - vector[N] log_x = z - r; - target += -log_x[1:N - 1]; - target += std_normal_lpdf(y) - lgamma(N); - return log_x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N] y; -} -transformed parameters { - vector[N] log_x = normalized_exponential_log_simplex_constrain_lp(y); - simplex[N] x = exp(log_x); -} -model { - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/log_simplex/StanStickbreaking.stan b/transforms/log_simplex/StanStickbreaking.stan deleted file mode 100644 index 3991414..0000000 --- a/transforms/log_simplex/StanStickbreaking.stan +++ /dev/null @@ -1,14 +0,0 @@ -data { - int N; - vector[N] alpha; -} -parameters { - simplex[N] x; -} -transformed parameters { - vector[N] log_x = log(x); -} -model { - target += -log_x[1:N - 1]; - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/log_simplex/StickbreakingAngular.stan b/transforms/log_simplex/StickbreakingAngular.stan deleted file mode 100644 index 6e03429..0000000 --- a/transforms/log_simplex/StickbreakingAngular.stan +++ /dev/null @@ -1,39 +0,0 @@ -functions { - vector stickbricking_angular_log_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] log_x; - real log_phi, phi, log_u, log_s, log_c; - real log_s2_prod = 0; - real log_halfpi = log(pi()) - log2(); - int rcounter = 2 * N - 3; - for (i in 1:(N-1)) { - log_u = log_inv_logit(y[i]); - log_phi = log_u + log_halfpi; - phi = exp(log_phi); - log_s = log(sin(phi)); - log_c = log(cos(phi)); - log_x[i] = log_s2_prod + 2 * log_c; - log_s2_prod += 2 * log_s; - target += log_phi + log1m_exp(log_u) + rcounter * log_s + log_c; - target += -log_x[i]; - rcounter -= 2; - } - log_x[N] = log_s2_prod; - target += (N - 1) * log2(); - return log_x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - vector[N] log_x = stickbricking_angular_log_simplex_constrain_lp(y); - simplex[N] x = exp(log_x); -} -model { - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/log_simplex/StickbreakingLogistic.stan b/transforms/log_simplex/StickbreakingLogistic.stan deleted file mode 100644 index 15c0d8c..0000000 --- a/transforms/log_simplex/StickbreakingLogistic.stan +++ /dev/null @@ -1,30 +0,0 @@ -functions { - vector stickbreaking_logistic_log_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] log_x; - real log_z; - real log_cum_prod = 0; - for (i in 1:(N - 1)) { - log_z = log_inv_logit(y[i] - log(N - i)); // logistic_lcdf(y[i] | log(N - i), 1) - log_x[i] = log_cum_prod + log_z; - log_cum_prod += log1m_exp(log_z); - } - log_x[N] = log_cum_prod; - target += log_cum_prod; - return log_x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - vector[N] log_x = stickbreaking_logistic_log_simplex_constrain_lp(y); - simplex[N] x = exp(log_x); -} -model { - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/log_simplex/StickbreakingNormal.stan b/transforms/log_simplex/StickbreakingNormal.stan deleted file mode 100644 index b68138b..0000000 --- a/transforms/log_simplex/StickbreakingNormal.stan +++ /dev/null @@ -1,31 +0,0 @@ -functions { - vector stickbreaking_normal_log_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] log_x; - real log_zi, wi; - real log_cum_prod = 0; - for (i in 1:N - 1) { - wi = y[i] - log(N - i) / 2; - log_zi = std_normal_lcdf(wi); - log_x[i] = log_cum_prod + log_zi; - target += std_normal_lpdf(wi) - log_zi; - log_cum_prod += log1m_exp(log_zi); - } - log_x[N] = log_cum_prod; - return log_x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - vector[N] log_x = stickbreaking_normal_log_simplex_constrain_lp(y); - simplex[N] x = exp(log_x); -} -model { - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/log_simplex/StickbreakingPowerLogistic.stan b/transforms/log_simplex/StickbreakingPowerLogistic.stan deleted file mode 100644 index 5d8f841..0000000 --- a/transforms/log_simplex/StickbreakingPowerLogistic.stan +++ /dev/null @@ -1,34 +0,0 @@ -functions { - vector stickbreaking_power_logistic_log_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] log_x; - real log_u, log_w, log_z; - real log_cum_prod = 0; - for (i in 1:(N-1)) { - log_u = log_inv_logit(y[i]); // logistic_lcdf(y[i] | 0, 1); - log_w = log_u / (N - i); - log_z = log1m_exp(log_w); - log_x[i] = log_cum_prod + log_z; - target += 2 * log_u - y[i]; // logistic_lupdf(y[i] | 0, 1); - target += -log_x[i]; - log_cum_prod += log1m_exp(log_z); - } - log_x[N] = log_cum_prod; - target += -lgamma(N); - return log_x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - vector[N] log_x = stickbreaking_power_logistic_log_simplex_constrain_lp(y); - simplex[N] x = exp(log_x); -} -model { - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/log_simplex/StickbreakingPowerNormal.stan b/transforms/log_simplex/StickbreakingPowerNormal.stan deleted file mode 100644 index 9663a1a..0000000 --- a/transforms/log_simplex/StickbreakingPowerNormal.stan +++ /dev/null @@ -1,34 +0,0 @@ -functions { - vector stickbreaking_power_normal_log_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] log_x; - real log_u, log_w, log_z; - real log_cum_prod = 0; - for (i in 1:(N-1)) { - log_u = std_normal_lcdf(y[i] |); - log_w = log_u / (N - i); - log_z = log1m_exp(log_w); - log_x[i] = log_cum_prod + log_z; - target += std_normal_lpdf(y[i] |); - target += -log_x[i]; - log_cum_prod += log1m_exp(log_z); - } - log_x[N] = log_cum_prod; - target += -lgamma(N); - return log_x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - vector[N] log_x = stickbreaking_power_normal_log_simplex_constrain_lp(y); - simplex[N] x = exp(log_x); -} -model { - target += target_density_lp(log_x, alpha); -} diff --git a/transforms/simplex/ALR.stan b/transforms/simplex/ALR.stan deleted file mode 100644 index e4973c0..0000000 --- a/transforms/simplex/ALR.stan +++ /dev/null @@ -1,24 +0,0 @@ -functions { - vector inv_alr_simplex_constrain_lp(vector y){ - int N = rows(y) + 1; - real r = log1p_exp(log_sum_exp(y)); - vector[N] x; - x[1:N - 1] = exp(y - r); - x[N] = exp(-r); - target += sum(y) - N * r; - return x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - simplex[N] x = inv_alr_simplex_constrain_lp(y); -} -model { - target += target_density_lp(x, alpha); -} diff --git a/transforms/simplex/ExpandedSoftmax.stan b/transforms/simplex/ExpandedSoftmax.stan deleted file mode 100644 index 96ddc09..0000000 --- a/transforms/simplex/ExpandedSoftmax.stan +++ /dev/null @@ -1,23 +0,0 @@ -functions{ - vector expanded_softmax_simplex_constrain_lp(vector y) { - int N = rows(y); - real r = log_sum_exp(y); - vector[N] x = exp(y - r); - target += sum(y) - N * r; // target += log(prod(x)) - target += std_normal_lpdf(r - log(N)); - return x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N] y; -} -transformed parameters { - simplex[N] x = expanded_softmax_simplex_constrain_lp(y); -} -model { - target += target_density_lp(x, alpha); -} diff --git a/transforms/simplex/ILR.stan b/transforms/simplex/ILR.stan deleted file mode 100644 index a081545..0000000 --- a/transforms/simplex/ILR.stan +++ /dev/null @@ -1,38 +0,0 @@ -functions { - matrix semiorthogonal_matrix(int N) { - matrix[N, N - 1] V; - real inv_nrm2; - for (n in 1:(N - 1)) { - inv_nrm2 = inv_sqrt(n * (n + 1)); - V[1:n, n] = rep_vector(inv_nrm2, n); - V[n + 1, n] = -n * inv_nrm2; - V[(n + 2):N, n] = rep_vector(0, N - n - 1); - } - return V; - } - - vector inv_ilr_simplex_constrain_lp(vector y, matrix V) { - int N = rows(y) + 1; - vector[N] z = V * y; - real r = log_sum_exp(z); - vector[N] x = exp(z - r); - target += sum(z) - N * r + 0.5 * log(N); - return x; - } -} -data { - int N; - vector[N] alpha; -} -transformed data { - matrix[N, N - 1] V = semiorthogonal_matrix(N); -} -parameters { - vector[N - 1] y; -} -transformed parameters { - simplex[N] x = inv_ilr_simplex_constrain_lp(y, V); -} -model { - target += target_density_lp(x, alpha); -} diff --git a/transforms/simplex/NormalizedExponential.stan b/transforms/simplex/NormalizedExponential.stan deleted file mode 100644 index 64e20df..0000000 --- a/transforms/simplex/NormalizedExponential.stan +++ /dev/null @@ -1,32 +0,0 @@ -functions { - real exponential_log_qf(real logp){ - return -log1m_exp(logp); - } - vector normalized_exponential_simplex_constrain_lp(vector y) { - int N = rows(y); - vector[N] z; - real log_u; - for (i in 1:N) { - log_u = std_normal_lcdf(y[i]); - z[i] = log(exponential_log_qf(log_u)); - } - real r = log_sum_exp(z); - vector[N] x = exp(z - r); - target += std_normal_lpdf(y) - lgamma(N); - return x; - } -} - -data { - int N; - vector[N] alpha; -} -parameters { - vector[N] y; -} -transformed parameters { - simplex[N] x = normalized_exponential_simplex_constrain_lp(y); -} -model { - target += target_density_lp(x, alpha); -} diff --git a/transforms/simplex/StanStickbreaking.stan b/transforms/simplex/StanStickbreaking.stan deleted file mode 100644 index 5f5b1e7..0000000 --- a/transforms/simplex/StanStickbreaking.stan +++ /dev/null @@ -1,10 +0,0 @@ -data { - int N; - vector[N] alpha; -} -parameters { - simplex[N] x; -} -model { - target += target_density_lp(x, alpha); -} diff --git a/transforms/simplex/StickbreakingAngular.stan b/transforms/simplex/StickbreakingAngular.stan deleted file mode 100644 index f2e3a54..0000000 --- a/transforms/simplex/StickbreakingAngular.stan +++ /dev/null @@ -1,37 +0,0 @@ -functions { - vector stickbricking_angular_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] x; - real log_phi, phi, u, s, c; - real s2_prod = 1; - real log_halfpi = log(pi()) - log2(); - int rcounter = 2 * N - 3; - for (i in 1:(N-1)) { - u = log_inv_logit(y[i]); - log_phi = u + log_halfpi; - phi = exp(log_phi); - s = sin(phi); - c = cos(phi); - x[i] = s2_prod * c^2; - s2_prod *= s^2; - target += log_phi + log1m_exp(u) + rcounter * log(s) + log(c); - rcounter -= 2; - } - x[N] = s2_prod; - target += (N - 1) * log2(); - return x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - simplex[N] x = stickbricking_angular_simplex_constrain_lp(y); -} -model { - target += target_density_lp(x, alpha); -} diff --git a/transforms/simplex/StickbreakingLogistic.stan b/transforms/simplex/StickbreakingLogistic.stan deleted file mode 100644 index b9f0449..0000000 --- a/transforms/simplex/StickbreakingLogistic.stan +++ /dev/null @@ -1,31 +0,0 @@ -functions { - vector stickbreaking_logistic_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] x; - real log_zi, log_xi; - real log_cum_prod = 0; - for (i in 1:N - 1) { - log_zi = log_inv_logit(y[i] - log(N - i)); // logistic_lcdf(y[i] | log(N - i), 1) - log_xi = log_cum_prod + log_zi; - x[i] = exp(log_xi); - log_cum_prod += log1m_exp(log_zi); - target += log_xi; - } - x[N] = exp(log_cum_prod); - target += log_cum_prod; - return x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - simplex[N] x = stickbreaking_logistic_simplex_constrain_lp(y); -} -model { - target += target_density_lp(x, alpha); -} diff --git a/transforms/simplex/StickbreakingNormal.stan b/transforms/simplex/StickbreakingNormal.stan deleted file mode 100644 index bc32c03..0000000 --- a/transforms/simplex/StickbreakingNormal.stan +++ /dev/null @@ -1,31 +0,0 @@ -functions { - vector stickbreaking_normal_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] x; - real log_zi, log_xi, wi; - real log_cum_prod = 0; - for (i in 1:N - 1) { - wi = y[i] - log(N - i) / 2; - log_zi = std_normal_lcdf(wi); - log_xi = log_cum_prod + log_zi; - x[i] = exp(log_xi); - target += std_normal_lpdf(wi) + log_cum_prod; - log_cum_prod += log1m_exp(log_zi); - } - x[N] = exp(log_cum_prod); - return x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - simplex[N] x = stickbreaking_normal_simplex_constrain_lp(y); -} -model { - target += target_density_lp(x, alpha); -} diff --git a/transforms/simplex/StickbreakingPowerLogistic.stan b/transforms/simplex/StickbreakingPowerLogistic.stan deleted file mode 100644 index 822cf66..0000000 --- a/transforms/simplex/StickbreakingPowerLogistic.stan +++ /dev/null @@ -1,32 +0,0 @@ -functions { - vector stickbreaking_power_logistic_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] x; - real log_u, log_w, log_z; - real log_cum_prod = 0; - for (i in 1:(N-1)) { - log_u = log_inv_logit(y[i]); // logistic_lcdf(y[i] | 0, 1); - log_w = log_u / (N - i); - log_z = log1m_exp(log_w); - x[i] = exp(log_cum_prod + log_z); - target += 2 * log_u - y[i]; // logistic_lupdf(y[i] | 0, 1); - log_cum_prod += log1m_exp(log_z); - } - x[N] = exp(log_cum_prod); - target += -lgamma(N); - return x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - simplex[N] x = stickbreaking_power_logistic_simplex_constrain_lp(y); -} -model { - target += target_density_lp(x, alpha); -} diff --git a/transforms/simplex/StickbreakingPowerNormal.stan b/transforms/simplex/StickbreakingPowerNormal.stan deleted file mode 100644 index 668ed14..0000000 --- a/transforms/simplex/StickbreakingPowerNormal.stan +++ /dev/null @@ -1,32 +0,0 @@ -functions { - vector stickbreaking_power_normal_simplex_constrain_lp(vector y) { - int N = rows(y) + 1; - vector[N] x; - real log_u, log_w, log_z; - real log_cum_prod = 0; - for (i in 1:(N-1)) { - log_u = std_normal_lcdf(y[i] |); - log_w = log_u / (N - i); - log_z = log1m_exp(log_w); - x[i] = exp(log_cum_prod + log_z); - target += std_normal_lpdf(y[i] |); - log_cum_prod += log1m_exp(log_z); - } - x[N] = exp(log_cum_prod); - target += -lgamma(N); - return x; - } -} -data { - int N; - vector[N] alpha; -} -parameters { - vector[N - 1] y; -} -transformed parameters { - simplex[N] x = stickbreaking_power_normal_simplex_constrain_lp(y); -} -model { - target += target_density_lp(x, alpha); -}