diff --git a/benchmark/examples/bart_examples.py b/benchmark/examples/bart_examples.py index a8c8663..553f13f 100644 --- a/benchmark/examples/bart_examples.py +++ b/benchmark/examples/bart_examples.py @@ -77,8 +77,7 @@ def test_coal(args): "mu", X=x_data, Y=np.log(y_data), - m=args.trees, - # split_rules=["ContinuousSplit"] + m=args.trees ) exp_mu = pm.Deterministic("exp_mu", pm.math.exp(mu)) y_pred = pm.Poisson("y_pred", mu=exp_mu, observed=y_data) @@ -92,7 +91,7 @@ def test_coal(args): ], random_seed=RANDOM_SEED, ) - # step = pmb.PGBART([mu], batch=tuple(args.batch), num_particles=args.particles) + # step = pmb.PGBART([mu], batch=tuple(args.batch), num_particles=args.particles) # for i in range(1500): # sum_trees, stats = step.astep(i) @@ -110,6 +109,25 @@ def test_coal(args): ax.set_ylabel("rate") plt.show() +def test_asymmetric_laplace(args): + bmi = pd.read_csv(pm.get_data("bmi.csv")) + + y = bmi.bmi.values + X = bmi.age.values[:, None] + y_stack = np.stack([bmi.bmi.values] * 3) + quantiles = np.array([[0.1, 0.5, 0.9]]).T + + coords = { + "quantiles": quantiles.flatten(), + "n_obs": np.arange(X.shape[0]) + } + + with pm.Model(coords=coords) as model: + mu = pmb.BART("mu", X, y, m=5, dims=["quantiles", "n_obs"]) + sigma = pm.HalfNormal("sigma", 5) + obs = pm.AsymmetricLaplace("obs", mu=mu, b=sigma, q=quantiles, observed=y_stack) + step = pmb.PGBART([mu], num_particles=3, batch=(0.1, 0.1)) + def main(args): @@ -119,6 +137,8 @@ def main(args): test_bikes(args) elif args.model == "propensity": test_propensity(args) + elif args.model == "asymmetric": + test_asymmetric_laplace(args) else: raise TypeError("Invalid model argument passed.") diff --git a/python/pymc_bart/bart.py b/python/pymc_bart/bart.py index ac6cf8b..834fdca 100644 --- a/python/pymc_bart/bart.py +++ b/python/pymc_bart/bart.py @@ -25,6 +25,9 @@ from pymc.distributions.distribution import Distribution, _support_point from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.sharedvar import TensorSharedVariable + +from .utils import _sample_posterior __all__ = ["BART"] @@ -32,32 +35,39 @@ class BARTRV(RandomVariable): """Base class for BART.""" name: str = "BART" - signature = "(m,n),(m),(),(),(),(k)->(m)" - # ndim_supp = 1 - ndims_params: List[int] = [2, 1, 0, 0, 0, 1] + signature = "(m,n),(m),(),(),() -> (m)" dtype: str = "floatX" _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") all_trees = None def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed - return dist_params[0].shape[:1] + idx = dist_params[0].ndim - 2 + return [dist_params[0].shape[idx]] @classmethod def rng_fn( # pylint: disable=W0237 - cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None + cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None ): + if not size: + size = None + + if isinstance(cls.Y, TensorSharedVariable): + Y = cls.Y.eval() + else: + Y = cls.Y + if not cls.all_trees: if size is not None: - return np.full((size[0], cls.Y.shape[0]), cls.Y.mean()) + return np.full((size[0], Y.shape[0]), Y.mean()) else: - return np.full(cls.Y.shape[0], cls.Y.mean()) - # TODO: !!! - # else: - # if size is not None: - # shape = size[0] - # else: - # shape = 1 - # return _sample_posterior(cls.all_trees, cls.X, rng=rng, shape=shape).squeeze().T + return np.full(Y.shape[0], Y.mean()) + else: + if size is not None: + shape = size[0] + else: + shape = 1 + raise NotImplementedError("_sample_posterior not implemented") + # return _sample_posterior(cls.all_trees, cls.X, rng=rng, shape=shape).squeeze().T bart = BARTRV() @@ -175,7 +185,7 @@ def get_moment(rv, size, *rv_inputs): return cls.get_moment(rv, size, *rv_inputs) cls.rv_op = bart_op - params = [X, Y, m, alpha, beta, split_prior] + params = [X, Y, m, alpha, beta] return super().__new__(cls, name, *params, **kwargs) @classmethod @@ -208,6 +218,16 @@ def preprocess_xy(X, Y) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]] if isinstance(X, (Series, DataFrame)): X = X.to_numpy() + try: + import polars as pl + + if isinstance(X, (pl.Series, pl.DataFrame)): + X = X.to_numpy() + if isinstance(Y, (pl.Series, pl.DataFrame)): + Y = Y.to_numpy() + except ImportError: + pass + Y = Y.astype(float) X = X.astype(float) diff --git a/python/pymc_bart/pgbart.py b/python/pymc_bart/pgbart.py index 18fa94d..34a19f7 100644 --- a/python/pymc_bart/pgbart.py +++ b/python/pymc_bart/pgbart.py @@ -95,12 +95,16 @@ def __init__( # noqa: PLR0915 shape = initial_point[value_bart.name].shape self.shape = 1 if len(shape) == 1 else shape[0] + print(f"shape: {shape}, self.shape: {self.shape}") + # Set trees_shape (dim for separate tree structures) # and leaves_shape (dim for leaf node values) # One of the two is always one, the other equal to self.shape self.trees_shape = self.shape if self.bart.separate_trees else 1 self.leaves_shape = self.shape if not self.bart.separate_trees else 1 + print(f"self.leaves_shape: {self.leaves_shape}") + if self.bart.split_prior.size == 0: self.alpha_vec = np.ones(self.X.shape[1]) else: @@ -112,7 +116,8 @@ def __init__( # noqa: PLR0915 self.split_rules = ["ContinuousSplit"] * self.X.shape[1] # If data is binary - self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape)) + # self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape)) + self.leaf_sd = np.ones(self.leaves_shape) y_unique = np.unique(self.bart.Y) if y_unique.size == 2 and np.all(y_unique == [0, 1]): @@ -120,11 +125,13 @@ def __init__( # noqa: PLR0915 else: self.leaf_sd *= self.bart.Y.std() / self.m ** 0.5 + print(f"self.leaf_std: {self.leaf_sd}") + # Compile the PyMC model to create a C callback. This function pointer is # passed to Rust and called using Rust's foreign function interface (FFI) self.compiled_pymc_model = CompiledPyMCModel(model, vars) - # Initialize the Rust Particle-Gibbs sampler + # Initialize the Rust Particle-Gibbs sampler state self.state = initialize( X=self.X, y=self.bart.Y, @@ -138,7 +145,7 @@ def __init__( # noqa: PLR0915 n_particles=num_particles, leaf_sd=self.leaf_sd, batch=batch, - _leaves_shape=self.leaves_shape, + leaves_shape=self.leaves_shape, ) self.tune = True diff --git a/python/pymc_bart/utils.py b/python/pymc_bart/utils.py index 33fc202..8fe9889 100644 --- a/python/pymc_bart/utils.py +++ b/python/pymc_bart/utils.py @@ -14,13 +14,11 @@ from scipy.signal import savgol_filter from scipy.stats import norm -# from .tree import Tree TensorLike = Union[npt.NDArray[np.float64], pt.TensorVariable] def _sample_posterior( - # all_trees: List[List[Tree]], all_trees, X: TensorLike, rng: np.random.Generator, @@ -50,7 +48,7 @@ def _sample_posterior( X = X.eval() if size is None: - size_iter: Union[List, Tuple] = (1,) + size_iter: Union[list, tuple] = (1,) elif isinstance(size, int): size_iter = [size] else: @@ -60,9 +58,6 @@ def _sample_posterior( for s in size_iter: flatten_size *= s - print(f"len(stacked_trees): {len(stacked_trees)}") - print(f"flatten_size: {flatten_size}") - idx = rng.integers(0, len(stacked_trees), size=flatten_size) trees_shape = len(stacked_trees[0]) diff --git a/src/lib.rs b/src/lib.rs index 3688138..7998fd0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,9 +63,9 @@ fn initialize( response: String, n_trees: usize, n_particles: usize, - leaf_sd: f64, + leaf_sd: Vec, batch: (f64, f64), - _leaves_shape: usize, + leaves_shape: usize, ) -> PyResult { // Heap allocation because size of 'ExternalData' is not known at compile time let data = Box::new(ExternalData::new(X, y, logp)); @@ -96,6 +96,7 @@ fn initialize( split_prior.to_vec().unwrap(), response, rules, + leaves_shape, ); let state = PgBartState::new(params, data); diff --git a/src/math.rs b/src/math.rs index d7506b0..60f5ffb 100644 --- a/src/math.rs +++ b/src/math.rs @@ -19,7 +19,7 @@ impl RunningStd { } /// Update the running statistics with a new value - pub fn update(&mut self, new_value: &[f64]) -> f64 { + pub fn update(&mut self, new_value: &[f64]) -> Vec { self.count += 1; let (mean, mean_2, std) = update_stats(self.count, &self.mean, &self.mean_2, new_value); self.mean = mean; @@ -61,9 +61,9 @@ fn update_stats( } /// Calculate the mean of the array -fn compute_mean(ari: &[f64]) -> f64 { +fn compute_mean(ari: &[f64]) -> Vec { let sum: f64 = ari.iter().sum(); - sum / ari.len() as f64 + vec![sum / ari.len() as f64] } /// Computes the normalized cumulative sum. diff --git a/src/ops.rs b/src/ops.rs index d20e73d..6049e41 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -121,12 +121,17 @@ impl TreeSamplingOps { mu: &[f64], _obs: &[f64], m: usize, - leaf_sd: &f64, - _shape: usize, + leaf_sd: &Vec, + _shape: &usize, response: &Response, ) -> f64 { let mut rng = thread_rng(); - let norm = self.normal.sample(&mut rng) * leaf_sd; + + if leaf_sd.len() > 1 { + todo!("Multiple `leaf_sd` not supported.") + } + + let norm = self.normal.sample(&mut rng) * leaf_sd[0]; match mu.len() { 0 => 0.0, diff --git a/src/particle.rs b/src/particle.rs index 63172b7..a124933 100644 --- a/src/particle.rs +++ b/src/particle.rs @@ -200,7 +200,7 @@ impl Particle { &observations.collect::>(), state.params.n_trees, &state.params.leaf_sd, - 1, // shape + &state.params.n_dim, &state.params.response, ) }; @@ -213,7 +213,7 @@ impl Particle { &observations.collect::>(), state.params.n_trees, &state.params.leaf_sd, - 1, // shape + &state.params.n_dim, &state.params.response, ) }; diff --git a/src/pgbart.rs b/src/pgbart.rs index 3852852..4e30248 100644 --- a/src/pgbart.rs +++ b/src/pgbart.rs @@ -34,7 +34,7 @@ pub struct PgBartSettings { /// beta parameter to control node depth. pub beta: f64, /// Leaf node standard deviation. - pub leaf_sd: f64, + pub leaf_sd: Vec, /// Batch size to use during tuning and draws. pub batch: (f64, f64), /// Initial prior probability over feature splitting probability. @@ -43,6 +43,8 @@ pub struct PgBartSettings { pub response: Response, /// Split rule strategy to use for sampling threshold (split) values. pub split_rules: Vec, + /// Number of dimensions for multi-output leaf values + pub n_dim: usize, } impl PgBartSettings { @@ -54,11 +56,12 @@ impl PgBartSettings { n_particles: usize, alpha: f64, beta: f64, - leaf_sd: f64, + leaf_sd: Vec, batch: (f64, f64), init_alpha_vec: Vec, response: Response, split_rules: Vec, + n_dim: usize, ) -> Self { Self { n_trees, @@ -70,6 +73,7 @@ impl PgBartSettings { init_alpha_vec, response, split_rules, + n_dim, } } } @@ -126,7 +130,6 @@ impl PgBartState { alpha: params.alpha, beta: params.beta, normal: Normal::new(0.0, 1.0).unwrap(), - // uniform: Uniform::new(0.0, 1.0), }; Self { @@ -220,7 +223,7 @@ impl PgBartState { self.update_splitting_probability(&new_particle); } - if self.iter > 2 { + if self.iter > 2 && self.params.leaf_sd.len() <= 1 { self.params.leaf_sd = self.tuning_stats.update(&new_particle_preds.to_vec()); } else { // Update tuning statistics without assigning a new leaf standard deviation @@ -310,6 +313,33 @@ pub fn resample_particles(particles: &mut Vec, weights: &[f64]) -> Vec // Move the first particle without cloning resampled_particles.push(particles[0].clone()); + // Pre-allocate index counts array instead of using HashMap + // Add 1 since we're using 1-based indexing for the rest of the particles + let mut index_counts = vec![0usize; num_particles]; + + // Generate systematic resampling indices and count occurrences + // Using a dedicated counter array instead of HashMap + let mut rng = thread_rng(); + let u = rng.gen::() / (num_particles - 1) as f64; + + let mut cumsum = 0.0; + let mut j = 1; // Start from 1 since we already handled particle 0 + + // Single pass to compute resampling indices + for i in 0..(num_particles - 1) { + let target = u + i as f64 / (num_particles - 1) as f64; + + while j < weights.len() && cumsum + weights[j] < target { + cumsum += weights[j]; + j += 1; + } + + // Increment count for this index + index_counts[j] += 1; + } + + println!("index_counts: {:?}", index_counts); + // Resample Particle indices and count number of occurences each index appears let mut index_counts = systematic_resample(weights, num_particles - 1) .map(|idx| idx + 1) @@ -318,6 +348,8 @@ pub fn resample_particles(particles: &mut Vec, weights: &[f64]) -> Vec acc }); + println!("index_counts: {:?}", index_counts); + // Stage 1: Process particles that need cloning, i.e. index count > 1 let mut to_remove = Vec::new(); for (&idx, &count) in &index_counts {