Skip to content

Commit

Permalink
Add itr argument to fit method and minor adds (#9)
Browse files Browse the repository at this point in the history
* Add itr arguments to fit method 
* Add links to papers in README
* Add project information in pyproject
  • Loading branch information
dirmeier authored Nov 19, 2022
1 parent 5212d16 commit 0a5c2ee
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

Reconcile implements probabilistic time series forecast reconciliation methods introduced in

1) Zambon, Lorenzo, Dario Azzimonti, and Giorgio Corani. "Probabilistic reconciliation of forecasts via importance sampling." arXiv preprint arXiv:2210.02286 (2022).
2) Panagiotelis, Anastasios, et al. "Probabilistic forecast reconciliation: Properties, evaluation and score optimisation." European Journal of Operational Research (2022).
1) Zambon, Lorenzo, Dario Azzimonti, and Giorgio Corani. ["Probabilistic reconciliation of forecasts via importance sampling."](https://doi.org/10.48550/arXiv.2210.02286) arXiv preprint arXiv:2210.02286 (2022).
2) Panagiotelis, Anastasios, et al. ["Probabilistic forecast reconciliation: Properties, evaluation and score optimisation."](https://doi.org/10.1016/j.ejor.2022.07.040) European Journal of Operational Research (2022).

The package implements

Expand Down
6 changes: 3 additions & 3 deletions examples/reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def run():

recon = ProbabilisticReconciliation(grouping, forecaster)
# do reconciliation via sampling
# _ = recon.sample_reconciled_posterior_predictive(
# random.PRNGKey(1), all_features, n_iter=100, n_warmup=50
# )
_ = recon.sample_reconciled_posterior_predictive(
random.PRNGKey(1), all_features, n_iter=100, n_warmup=50
)
# do reconciliation via optimization of the energy score
_ = recon.fit_reconciled_posterior_predictive(
random.PRNGKey(1), all_features, n_samples=100
Expand Down
13 changes: 12 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
[tool.poetry]
name = "probabilistic-reconciliation"
version = "0.0.2"
version = "0.0.3"
description = "Probabilistic reconciliation of time series forecasts"
authors = ["Simon Dirmeier <[email protected]>"]
readme = "README.md"
license = "Apache-2.0"
homepage = "https://github.com/dirmeier/reconcile"
keywords = ["probabilistic reconciliation", "forecasting", "timeseries", "hierarchical time series"]
classifiers=[
"Development Status :: 1 - Planning",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
]
packages = [{include = "reconcile"}]


Expand Down
12 changes: 10 additions & 2 deletions reconcile/probabilistic_reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def fit_reconciled_posterior_predictive(
xs_test: Array,
n_samples=2000,
net: Callable = None,
n_iter: int = None,
):
"""
Probabilistic reconciliation using energy score optimization
Expand All @@ -147,7 +148,10 @@ def fit_reconciled_posterior_predictive(
n_samples: int
number of samples to return
net: Callable
a flax neural network that is used for the projection
a flax neural network that is used for the projection or None to use
the linear projection from [1]
n_iter: int
number of iterations to train the network or None for early stopping
Returns
-------
Expand Down Expand Up @@ -221,6 +225,7 @@ def loss_fn(params):

batch_size = 64
early_stop = EarlyStopping(min_delta=0.1, patience=5)
itr = 0
while True:
sample_key, epoch_key, rng_key = random.split(rng_key, 3)
y_predictive_batch = predictive.sample(
Expand All @@ -230,9 +235,12 @@ def loss_fn(params):
state, loss = _step(state, epoch_key, ys, y_predictive_batch)
logger.info("Loss after batch update %d", loss)
_, early_stop = early_stop.update(loss)
if early_stop.should_stop:
if early_stop.should_stop and n_iter is None:
logger.info("Met early stopping criteria, breaking...")
break
elif n_iter is not None and itr == n_iter:
break
itr += 1

predictive = self._forecaster.posterior_predictive(rng_key, xs_test)
y_predictive = predictive.sample(
Expand Down

0 comments on commit 0a5c2ee

Please sign in to comment.