Merge pull request #44 from GStechschulte/multioutput
add initial support for multi-output leaf nodes
GStechschulte authored Jan 10, 2025
2 parents 419fefc + f7bb240 commit 828b941
Showing 9 changed files with 121 additions and 41 deletions.
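For context, the core change is that leaf nodes can now hold a vector of values instead of a single scalar, so one forest can predict several outputs at once (for example, one BMI curve per quantile in the new benchmark below). A rough sketch of the resulting shapes, not taken from the commit itself:

    import numpy as np

    n_outputs, n_obs = 3, 100                 # e.g. three quantiles over 100 observations
    leaf_value = np.zeros(n_outputs)          # a leaf now carries one value per output
    sum_trees = np.zeros((n_outputs, n_obs))  # the forest's prediction becomes 2-D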
26 changes: 23 additions & 3 deletions benchmark/examples/bart_examples.py
@@ -77,8 +77,7 @@ def test_coal(args):
"mu",
X=x_data,
Y=np.log(y_data),
m=args.trees,
# split_rules=["ContinuousSplit"]
m=args.trees
)
exp_mu = pm.Deterministic("exp_mu", pm.math.exp(mu))
y_pred = pm.Poisson("y_pred", mu=exp_mu, observed=y_data)
@@ -92,7 +91,7 @@ def test_coal(args):
],
random_seed=RANDOM_SEED,
)
# step = pmb.PGBART([mu], batch=tuple(args.batch), num_particles=args.particles)
# step = pmb.PGBART([mu], batch=tuple(args.batch), num_particles=args.particles)

# for i in range(1500):
# sum_trees, stats = step.astep(i)
@@ -110,6 +109,25 @@ def test_coal(args):
ax.set_ylabel("rate")
plt.show()

def test_asymmetric_laplace(args):
bmi = pd.read_csv(pm.get_data("bmi.csv"))

y = bmi.bmi.values
X = bmi.age.values[:, None]
y_stack = np.stack([bmi.bmi.values] * 3)
quantiles = np.array([[0.1, 0.5, 0.9]]).T

coords = {
"quantiles": quantiles.flatten(),
"n_obs": np.arange(X.shape[0])
}

with pm.Model(coords=coords) as model:
mu = pmb.BART("mu", X, y, m=5, dims=["quantiles", "n_obs"])
sigma = pm.HalfNormal("sigma", 5)
obs = pm.AsymmetricLaplace("obs", mu=mu, b=sigma, q=quantiles, observed=y_stack)
step = pmb.PGBART([mu], num_particles=3, batch=(0.1, 0.1))


def main(args):

@@ -119,6 +137,8 @@ def main(args):
test_bikes(args)
elif args.model == "propensity":
test_propensity(args)
elif args.model == "asymmetric":
test_asymmetric_laplace(args)
else:
raise TypeError("Invalid model argument passed.")
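test_asymmetric_laplace above builds the quantile model and the PGBART step but stops before sampling. A hedged continuation, assuming the standard PyMC workflow (model, step, and RANDOM_SEED are names from the benchmark file, not new ones):

        # Hypothetical continuation inside test_asymmetric_laplace (not in the commit).
        with model:
            idata = pm.sample(step=[step], chains=1, random_seed=RANDOM_SEED)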

50 changes: 35 additions & 15 deletions python/pymc_bart/bart.py
@@ -25,39 +25,49 @@
from pymc.distributions.distribution import Distribution, _support_point
from pymc.logprob.abstract import _logprob
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.sharedvar import TensorSharedVariable

from .utils import _sample_posterior

__all__ = ["BART"]

class BARTRV(RandomVariable):
"""Base class for BART."""

name: str = "BART"
signature = "(m,n),(m),(),(),(),(k)->(m)"
# ndim_supp = 1
ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
signature = "(m,n),(m),(),(),() -> (m)"
dtype: str = "floatX"
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
all_trees = None

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
return dist_params[0].shape[:1]
idx = dist_params[0].ndim - 2
return [dist_params[0].shape[idx]]

@classmethod
def rng_fn( # pylint: disable=W0237
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None
):
if not size:
size = None

if isinstance(cls.Y, TensorSharedVariable):
Y = cls.Y.eval()
else:
Y = cls.Y

if not cls.all_trees:
if size is not None:
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
return np.full((size[0], Y.shape[0]), Y.mean())
else:
return np.full(cls.Y.shape[0], cls.Y.mean())
# TODO: !!!
# else:
# if size is not None:
# shape = size[0]
# else:
# shape = 1
# return _sample_posterior(cls.all_trees, cls.X, rng=rng, shape=shape).squeeze().T
return np.full(Y.shape[0], Y.mean())
else:
if size is not None:
shape = size[0]
else:
shape = 1
raise NotImplementedError("_sample_posterior not implemented")
# return _sample_posterior(cls.all_trees, cls.X, rng=rng, shape=shape).squeeze().T


bart = BARTRV()
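A note on the new gufunc signature "(m,n),(m),(),(),() -> (m)": X is (m, n) (observations by predictors), Y is (m), and the remaining three scalar inputs are m (number of trees), alpha, and beta, so the support dimension stays at the number of observations; split_prior is no longer an RV input and is instead read off the BART instance inside PGBART (self.bart.split_prior, see the pgbart.py hunk below). The rewritten _supp_shape_from_params takes the observation axis at ndim - 2 of X so it also works with leading batch dimensions. A small illustration, not from the commit:

    import numpy as np

    # Axis ndim - 2 of X is the observation axis, with or without batch dims.
    for X in (np.zeros((100, 4)), np.zeros((2, 100, 4))):
        idx = X.ndim - 2
        print(X.shape, "->", X.shape[idx])    # both print 100 as the support size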
@@ -175,7 +185,7 @@ def get_moment(rv, size, *rv_inputs):
return cls.get_moment(rv, size, *rv_inputs)

cls.rv_op = bart_op
params = [X, Y, m, alpha, beta, split_prior]
params = [X, Y, m, alpha, beta]
return super().__new__(cls, name, *params, **kwargs)

@classmethod
@@ -208,6 +218,16 @@ def preprocess_xy(X, Y) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]
if isinstance(X, (Series, DataFrame)):
X = X.to_numpy()

try:
import polars as pl

if isinstance(X, (pl.Series, pl.DataFrame)):
X = X.to_numpy()
if isinstance(Y, (pl.Series, pl.DataFrame)):
Y = Y.to_numpy()
except ImportError:
pass

Y = Y.astype(float)
X = X.astype(float)

13 changes: 10 additions & 3 deletions python/pymc_bart/pgbart.py
@@ -95,12 +95,16 @@ def __init__( # noqa: PLR0915
shape = initial_point[value_bart.name].shape
self.shape = 1 if len(shape) == 1 else shape[0]

print(f"shape: {shape}, self.shape: {self.shape}")

# Set trees_shape (dim for separate tree structures)
# and leaves_shape (dim for leaf node values)
# One of the two is always one, the other equal to self.shape
self.trees_shape = self.shape if self.bart.separate_trees else 1
self.leaves_shape = self.shape if not self.bart.separate_trees else 1

print(f"self.leaves_shape: {self.leaves_shape}")

if self.bart.split_prior.size == 0:
self.alpha_vec = np.ones(self.X.shape[1])
else:
@@ -112,19 +116,22 @@
self.split_rules = ["ContinuousSplit"] * self.X.shape[1]

# If data is binary
self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))
# self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))
self.leaf_sd = np.ones(self.leaves_shape)

y_unique = np.unique(self.bart.Y)
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
self.leaf_sd *= 3 / self.m**0.5
else:
self.leaf_sd *= self.bart.Y.std() / self.m ** 0.5

print(f"self.leaf_std: {self.leaf_sd}")

# Compile the PyMC model to create a C callback. This function pointer is
# passed to Rust and called using Rust's foreign function interface (FFI)
self.compiled_pymc_model = CompiledPyMCModel(model, vars)

# Initialize the Rust Particle-Gibbs sampler
# Initialize the Rust Particle-Gibbs sampler state
self.state = initialize(
X=self.X,
y=self.bart.Y,
@@ -138,7 +145,7 @@
n_particles=num_particles,
leaf_sd=self.leaf_sd,
batch=batch,
_leaves_shape=self.leaves_shape,
leaves_shape=self.leaves_shape,
)

self.tune = True
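The shape bookkeeping above decides where the extra output dimension lives: with separate_trees each output gets its own tree structures, otherwise a shared forest carries vector-valued leaves, and leaf_sd gets one entry per leaf output. A small illustration of the two cases, not from the commit:

    import numpy as np

    shape = 3                                        # e.g. three quantile outputs
    for separate_trees in (False, True):
        trees_shape = shape if separate_trees else 1
        leaves_shape = shape if not separate_trees else 1
        leaf_sd = np.ones(leaves_shape)              # one leaf standard deviation per output
        print(separate_trees, trees_shape, leaves_shape)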
7 changes: 1 addition & 6 deletions python/pymc_bart/utils.py
@@ -14,13 +14,11 @@
from scipy.signal import savgol_filter
from scipy.stats import norm

# from .tree import Tree

TensorLike = Union[npt.NDArray[np.float64], pt.TensorVariable]


def _sample_posterior(
# all_trees: List[List[Tree]],
all_trees,
X: TensorLike,
rng: np.random.Generator,
@@ -50,7 +48,7 @@ def _sample_posterior(
X = X.eval()

if size is None:
size_iter: Union[List, Tuple] = (1,)
size_iter: Union[list, tuple] = (1,)
elif isinstance(size, int):
size_iter = [size]
else:
@@ -60,9 +58,6 @@
for s in size_iter:
flatten_size *= s

print(f"len(stacked_trees): {len(stacked_trees)}")
print(f"flatten_size: {flatten_size}")

idx = rng.integers(0, len(stacked_trees), size=flatten_size)

trees_shape = len(stacked_trees[0])
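For reference, the size handling above normalises a scalar or tuple size into an iterable and multiplies its entries into one flat draw count used to index stacked_trees. A compact restatement, not from the commit:

    size_iter = (2, 5)         # as produced by the branches above
    flatten_size = 1
    for s in size_iter:
        flatten_size *= s      # 10 posterior draws to index into stacked_trees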
5 changes: 3 additions & 2 deletions src/lib.rs
@@ -63,9 +63,9 @@ fn initialize(
response: String,
n_trees: usize,
n_particles: usize,
leaf_sd: f64,
leaf_sd: Vec<f64>,
batch: (f64, f64),
_leaves_shape: usize,
leaves_shape: usize,
) -> PyResult<StateWrapper> {
// Heap allocation because size of 'ExternalData' is not known at compile time
let data = Box::new(ExternalData::new(X, y, logp));
@@ -96,6 +96,7 @@
split_prior.to_vec().unwrap(),
response,
rules,
leaves_shape,
);
let state = PgBartState::new(params, data);
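On the FFI boundary, leaf_sd now arrives from Python as a vector with one entry per leaf output, alongside the leaves_shape it corresponds to. A hedged sketch of the Python-side values crossing this boundary (the full keyword list is the one shown in the pgbart.py hunk above):

    import numpy as np

    leaves_shape = 3
    leaf_sd = np.ones(leaves_shape) * 0.5   # one entry per leaf output
    # state = initialize(X=X, y=y, leaf_sd=leaf_sd, leaves_shape=leaves_shape, ...)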

6 changes: 3 additions & 3 deletions src/math.rs
@@ -19,7 +19,7 @@ impl RunningStd {
}

/// Update the running statistics with a new value
pub fn update(&mut self, new_value: &[f64]) -> f64 {
pub fn update(&mut self, new_value: &[f64]) -> Vec<f64> {
self.count += 1;
let (mean, mean_2, std) = update_stats(self.count, &self.mean, &self.mean_2, new_value);
self.mean = mean;
@@ -61,9 +61,9 @@ fn update_stats(
}

/// Calculate the mean of the array
fn compute_mean(ari: &[f64]) -> f64 {
fn compute_mean(ari: &[f64]) -> Vec<f64> {
let sum: f64 = ari.iter().sum();
sum / ari.len() as f64
vec![sum / ari.len() as f64]
}

/// Computes the normalized cumulative sum.
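RunningStd::update and compute_mean now return vectors so that each leaf output can eventually track its own running standard deviation; at this stage compute_mean still collapses to a single averaged entry. A one-line Python mirror of the new return shape, not from the commit:

    def compute_mean(ari):
        return [sum(ari) / len(ari)]   # one-element vector instead of a scalar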
11 changes: 8 additions & 3 deletions src/ops.rs
@@ -121,12 +121,17 @@ impl TreeSamplingOps {
mu: &[f64],
_obs: &[f64],
m: usize,
leaf_sd: &f64,
_shape: usize,
leaf_sd: &Vec<f64>,
_shape: &usize,
response: &Response,
) -> f64 {
let mut rng = thread_rng();
let norm = self.normal.sample(&mut rng) * leaf_sd;

if leaf_sd.len() > 1 {
todo!("Multiple `leaf_sd` not supported.")
}

let norm = self.normal.sample(&mut rng) * leaf_sd[0];

match mu.len() {
0 => 0.0,
Expand Down
4 changes: 2 additions & 2 deletions src/particle.rs
@@ -200,7 +200,7 @@ impl Particle {
&observations.collect::<Vec<_>>(),
state.params.n_trees,
&state.params.leaf_sd,
1, // shape
&state.params.n_dim,
&state.params.response,
)
};
@@ -213,7 +213,7 @@
&observations.collect::<Vec<_>>(),
state.params.n_trees,
&state.params.leaf_sd,
1, // shape
&state.params.n_dim,
&state.params.response,
)
};