Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove nonfunctioning GPJax #15

Merged
merged 3 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ jobs:
- name: Install dependencies
run: |
pip install hatch
- name: Build package
run: |
pip install jaxlib jax
- name: Run tests
run: |
hatch run test:test
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install pypa/build
Expand Down
126 changes: 61 additions & 65 deletions examples/reconciliation.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,38 @@
import chex
import distrax
import gpjax as gpx
import jax
import numpy as np
import optax
import pandas as pd
from einops import rearrange
from jax import Array
from jax import numpy as jnp
from jax import random as jr
from ramsey import NP, train_neural_process
from ramsey.nn import MLP
from statsmodels.tsa.arima_process import arma_generate_sample
from tensorflow_probability.substrates.jax import distributions as tfd

from reconcile.forecast import Forecaster
from reconcile.grouping import Grouping
from reconcile.probabilistic_reconciliation import ProbabilisticReconciliation

jax.config.update("jax_enable_x64", True)


class GPForecaster(Forecaster):
"""Example implementation of a forecaster"""
class NeuralProcessForecaster(Forecaster):
"""Example implementation of a forecaster."""

def __init__(self):
super().__init__()
self._models: list = []
self._xs: jax.Array = None
self._ys: jax.Array = None
self._xs: jax.Array
self._ys: jax.Array

@property
def data(self):
"""Returns the data"""
return self._ys, self._xs
def data(self) -> tuple[Array, Array]:
return self._xs, self._ys

def fit(
self, rng_key: jr.PRNGKey, ys: jax.Array, xs: jax.Array, niter=2000
):
"""Fit a model to each of the time series"""

"""Fit a model to each of the time series."""
self._xs = xs
self._ys = ys
chex.assert_rank([ys, xs], [3, 3])
Expand All @@ -43,70 +41,68 @@ def fit(
p = xs.shape[1]
self._models = [None] * p
for i in np.arange(p):
x, y = xs[:, [i], :], ys[:, [i], :]
x, y = xs[..., i, :], ys[..., i, :]
# fit a model for each time series
opt_posterior, _, D = self._fit_one(rng_key, x, y, niter)
model, params = self._fit_one(rng_key, x, y, niter)
# save the learned parameters and the original data
self._models[i] = opt_posterior, D
self._models[i] = model, params

def _fit_one(self, rng_key, x, y, niter):
# here we use GPs to model the time series
D = gpx.Dataset(X=x.reshape(-1, 1), y=y.reshape(-1, 1))
elbo, q, likelihood = self._model(rng_key, D.n)

negative_elbo = jax.jit(elbo)
optimiser = optax.adam(learning_rate=5e-3)
opt_posterior, history = gpx.fit(
model=q,
objective=negative_elbo,
train_data=D,
optim=optimiser,
num_iters=niter,
key=rng_key,
# here we use neural processes to model the time series
model = self._model()
n_context, n_target = 10, 20
params, _ = train_neural_process(
rng_key,
model,
x=x.reshape(1, -1, 1),
y=y.reshape(1, -1, 1),
n_context=n_context,
n_target=n_target,
n_iter=1000,
batch_size=1,
)
return opt_posterior, history, D
return model, params

@staticmethod
def _model(rng_key, n):
z = jr.uniform(rng_key, (20, 1))
prior = gpx.gps.Prior(
mean_function=gpx.mean_functions.Constant(),
kernel=gpx.kernels.RBF(),
)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
posterior = prior * likelihood
q = gpx.variational_families.CollapsedVariationalGaussian(
posterior=posterior,
inducing_inputs=z,
)
elbo = gpx.objectives.CollapsedELBO(negative=True)
return elbo, q, likelihood
def _model():
def get_neural_process():
dim = 128
np = NP(
decoder=MLP([dim] * 3 + [2]),
latent_encoder=(MLP([dim] * 3), MLP([dim, dim * 2])),
)
return np

neural_process = get_neural_process()
return neural_process

def posterior_predictive(self, rng_key, xs_test: jax.Array):
"""Compute the joint posterior predictive distribution at xs_test."""
chex.assert_rank(xs_test, 3)

q = xs_test.shape[1]
means = [None] * q
covs = [None] * q
scales = [None] * q
for i in np.arange(q):
x_test = xs_test[:, [i], :].reshape(-1, 1)
opt_posterior, D = self._models[i]
_, q, likelihood = self._model(rng_key, D.n)
latent_dist = opt_posterior(x_test, train_data=D)
predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
means[i] = predictive_dist.mean()
cov = predictive_dist.scale_tril
covs[i] = cov.reshape((1, *cov.shape))

# here we stack the means and covariance functions of all
# GP models we used
means = jnp.vstack(means)
covs = jnp.vstack(covs)

# here we use a single distrax distribution to model the predictive
x_context = self._xs[..., i, :]
y_context = self._ys[..., i, :]
x_test = xs_test[..., i, :]

model, params = self._models[i]
predictive_dist = model.apply(
variables=params,
rngs={"sample": rng_key},
x_context=x_context.reshape(1, -1, 1),
y_context=y_context.reshape(1, -1, 1),
x_target=x_test.reshape(1, -1, 1),
)
means[i] = predictive_dist.mean
scales[i] = predictive_dist.scale

means = rearrange(jnp.vstack(means), "b t ... -> ... b t")
scales = rearrange(jnp.vstack(scales), "b t ... -> ... b t")
# posterior of _all_ models
posterior_predictive = distrax.MultivariateNormalTri(means, covs)
posterior_predictive = tfd.MultivariateNormalDiag(means, scales)
return posterior_predictive

def predictive_posterior_probability(
Expand Down Expand Up @@ -147,12 +143,12 @@ def sample_hierarchical_timeseries():
and the second one is a pd.DataFrame of groups
"""

def _group_names():
def _hierarchy():
hierarchy = ["A:10", "A:20", "B:10", "B:20", "B:30"]

return pd.DataFrame.from_dict({"h1": hierarchy})

return _sample_timeseries(100, 5), _group_names()
return _sample_timeseries(100, 5), _hierarchy()


def run():
Expand All @@ -161,7 +157,7 @@ def run():
all_timeseries = grouping.all_timeseries(b)
all_features = jnp.tile(x, [1, all_timeseries.shape[1], 1])

forecaster = GPForecaster()
forecaster = NeuralProcessForecaster()
forecaster.fit(
jr.PRNGKey(1),
all_timeseries[:, :, :90],
Expand Down
44 changes: 20 additions & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,40 @@ name = "probabilistic-reconciliation"
description = "Probabilistic reconciliation of time series forecasts"
authors = [{name = "Simon Dirmeier", email = "[email protected]"}]
readme = "README.md"
license = "Apache-2.0"
homepage = "https://github.com/dirmeier/reconcile"
license = {file = "LICENSE"}
keywords = ["probabilistic reconciliation", "forecasting", "timeseries", "hierarchical time series"]
classifiers=[
classifiers= [
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
requires-python = ">=3.9"
requires-python = ">=3.11"
dependencies = [
"blackjax-nightly>=0.9.6.post127",
"distrax>=0.1.2",
"chex>=0.1.5",
"jaxlib>=0.4.18",
"jax>=0.4.18",
"flax>=0.7.3",
"gpjax>=0.6.9",
"optax>=0.1.3",
"pandas>=1.5.1"
"blackjax>=1.2.4",
"chex>=0.1.8",
"einops>=0.8.0",
"flax>=0.10.2",
"jax>=0.4.38",
"optax>=0.2.4",
"pandas>=1.5.1",
"ramsey>=0.2.1",
"tfp-nightly[jax]>=0.26.0.dev20241227",
]
dynamic = ["version"]

[project.urls]
homepage = "https://github.com/dirmeier/reconcile"

[tool.hatch.build.targets.wheel]
packages = ["reconcile"]

[tool.hatch.version]
path = "reconcile/__init__.py"

[tool.hatch.build.targets.wheel]
packages = ["reconcile"]

[tool.hatch.build.targets.sdist]
exclude = [
"/.github",
Expand All @@ -55,20 +53,18 @@ dependencies = [
"ruff>=0.3.0",
"pytest>=7.2.0",
"pytest-cov>=4.0.0",
"gpjax>=0.5.0",
"statsmodels>=0.13.2"
]

[tool.hatch.envs.test.scripts]
lint = 'ruff check reconcile examples'
test = 'pytest -v --doctest-modules --cov=./reconcile --cov-report=xml reconcile'

[tool.hatch.envs.examples]
dependencies = [
"gpjax>=0.5.0",
"statsmodels>=0.13.2"
]

[tool.hatch.envs.test.scripts]
lint = 'ruff check reconcile examples'
test = 'pytest -v --doctest-modules --cov=./reconcile --cov-report=xml reconcile'

[tool.hatch.envs.examples.scripts]
reconciliation = 'python examples/reconciliation.py'

Expand Down
2 changes: 1 addition & 1 deletion reconcile/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""reconcile: Probabilistic reconciliation of time series forecasts."""

__version__ = "0.1.0"
__version__ = "0.2.0"

from reconcile.forecast import Forecaster
from reconcile.grouping import Grouping
Expand Down
8 changes: 4 additions & 4 deletions reconcile/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jax import numpy as jnp
from jax import random

from examples.reconciliation import GPForecaster
from examples.reconciliation import NeuralProcessForecaster
from reconcile import ProbabilisticReconciliation
from reconcile.grouping import Grouping

Expand Down Expand Up @@ -40,11 +40,11 @@ def reconciliator():
all_timeseries = grouping.all_timeseries(b)
all_features = jnp.tile(x, [1, all_timeseries.shape[1], 1])

forecaster = GPForecaster()
forecaster = NeuralProcessForecaster()
forecaster.fit(
random.PRNGKey(1),
all_timeseries[:, :90, :],
all_features[:, :90, :],
all_timeseries[:, :, :90],
all_features[:, :, :90],
100,
)

Expand Down
6 changes: 3 additions & 3 deletions reconcile/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import abc

import distrax
from jax import Array
from jax import random as jr
from tensorflow_probability.substrates.jax import distributions as tfp


class Forecaster(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -49,7 +49,7 @@ def fit(self, rng_key: jr.PRNGKey, ys: Array, xs: Array) -> None:
@abc.abstractmethod
def posterior_predictive(
self, rng_key: jr.PRNGKey, xs_test: Array
) -> distrax.Distribution:
) -> tfp.Distribution:
"""Computes the posterior predictive distribution at some input points.

Args:
Expand All @@ -61,7 +61,7 @@ def posterior_predictive(
elements as the original training data

Return:
returns a distrax Distribution with batch shape (,P) and event
returns a TFP Distribution with batch shape (,P) and event
shape (,M), such that a single sample has shape (P, M) and
multiple samples have shape (S, P, M)
"""
Expand Down
4 changes: 2 additions & 2 deletions reconcile/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def summing_matrix(self):

def extract_bottom_timeseries(self, y):
"""Getter for the bottom time series."""
return y[:, self.n_upper_timeseries :, :]
return y[..., self.n_upper_timeseries :, :]

def upper_time_series(self, b):
"""Getter for upper time series."""
y = self.all_timeseries(b)
return y[:, : self.n_upper_timeseries, :]
return y[..., : self.n_upper_timeseries, :]

@staticmethod
def _paste0(a, b):
Expand Down
Loading
Loading