Skip to content

Commit

Permalink
Imp/Add new model StatsForecastAutoTBATS (#2611)
Browse files Browse the repository at this point in the history
* Add new model StatsForecastAutoTBATS

* Update darts/models/forecasting/sf_auto_tbats.py

Co-authored-by: Dennis Bader <[email protected]>

* Update CHANGELOG

* Update README.md

* Update covariates.md

* Update test_probabilistic_models.py

* Update test_probabilistic_models.py

* update changelog and readme

---------

Co-authored-by: Dennis Bader <[email protected]>
  • Loading branch information
cnhwl and dennisbader authored Dec 24, 2024
1 parent 441a58a commit aad1440
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**

- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl).

**Fixed**

**Dependencies**
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ on bringing more models and features.
| [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 ✅ 🔴 | ✅ 🔴 | 🔴 |
| [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 |
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | [TBATS paper](https://robjhyndman.com/papers/ComplexSeasonality.pdf) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | [Theta](https://robjhyndman.com/papers/Theta.pdf) & [4 Theta](https://github.com/Mcompetitions/M4-methods/blob/master/4Theta%20method.R) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 |
| [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) | [Prophet repo](https://github.com/facebook/prophet) | ✅ 🔴 | 🔴 ✅ 🔴 | ✅ 🔴 | 🔴 |
Expand Down
3 changes: 3 additions & 0 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
from darts.models.forecasting.sf_auto_arima import StatsForecastAutoARIMA
from darts.models.forecasting.sf_auto_ces import StatsForecastAutoCES
from darts.models.forecasting.sf_auto_ets import StatsForecastAutoETS
from darts.models.forecasting.sf_auto_tbats import StatsForecastAutoTBATS
from darts.models.forecasting.sf_auto_theta import StatsForecastAutoTheta

except ImportError:
Expand All @@ -108,6 +109,7 @@
StatsForecastAutoCES = NotImportedModule(module_name="StatsForecast", warn=False)
StatsForecastAutoETS = NotImportedModule(module_name="StatsForecast", warn=False)
StatsForecastAutoTheta = NotImportedModule(module_name="StatsForecast", warn=False)
StatsForecastAutoTBATS = NotImportedModule(module_name="StatsForecast", warn=False)

try:
from darts.models.forecasting.xgboost import XGBModel
Expand Down Expand Up @@ -160,6 +162,7 @@
"StatsForecastAutoCES",
"StatsForecastAutoETS",
"StatsForecastAutoTheta",
"StatsForecastAutoTBATS",
"XGBModel",
"GaussianProcessFilter",
"KalmanFilter",
Expand Down
1 change: 1 addition & 0 deletions darts/models/forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- :class:`~darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES`
- :class:`~darts.models.forecasting.tbats_model.BATS`
- :class:`~darts.models.forecasting.tbats_model.TBATS`
- :class:`~darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS`
- :class:`~darts.models.forecasting.theta.Theta`
- :class:`~darts.models.forecasting.theta.FourTheta`
- :class:`~darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta`
Expand Down
104 changes: 104 additions & 0 deletions darts/models/forecasting/sf_auto_tbats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
StatsForecastAutoTBATS
-----------
"""

from statsforecast.models import AutoTBATS as SFAutoTBATS

from darts import TimeSeries
from darts.models.components.statsforecast_utils import (
create_normal_samples,
one_sigma_rule,
unpack_sf_dict,
)
from darts.models.forecasting.forecasting_model import LocalForecastingModel


class StatsForecastAutoTBATS(LocalForecastingModel):
def __init__(self, *autoTBATS_args, **autoTBATS_kwargs):
"""Auto-TBATS based on `Statsforecasts package
<https://github.com/Nixtla/statsforecast>`_.
Automatically selects the best TBATS model from all feasible combinations of the parameters `use_boxcox`,
`use_trend`, `use_damped_trend`, and `use_arma_errors`. Selection is made using the AIC.
Default value for `use_arma_errors` is True since this enables the evaluation of models with
and without ARMA errors.
<https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=f3de25596ab60ef0e886366826bf58a02b35a44f>
<https://doi.org/10.4225/03/589299681de3d>
We refer to the `statsforecast AutoTBATS documentation
<https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats>`_
for the exhaustive documentation of the arguments.
Parameters
----------
autoTBATS_args
Positional arguments for ``statsforecasts.models.AutoTBATS``.
autoTBATS_kwargs
Keyword arguments for ``statsforecasts.models.AutoTBATS``.
Examples
--------
>>> from darts.datasets import AirPassengersDataset
>>> from darts.models import StatsForecastAutoTBATS
>>> series = AirPassengersDataset().load()
>>> # define StatsForecastAutoTBATS parameters
>>> model = StatsForecastAutoTBATS(season_length=12)
>>> model.fit(series)
>>> pred = model.predict(6)
>>> pred.values()
array([[450.79653684],
[472.09265790],
[497.76948306],
[510.74927369],
[520.92224557],
[570.33881522]])
"""
super().__init__()
self.model = SFAutoTBATS(*autoTBATS_args, **autoTBATS_kwargs)

def fit(self, series: TimeSeries):
super().fit(series)
self._assert_univariate(series)
series = self.training_series
self.model.fit(
series.values(copy=False).flatten(),
)
return self

def predict(
self,
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)
forecast_dict = self.model.predict(
h=n,
level=(one_sigma_rule,), # ask one std for the confidence interval.
)

mu, std = unpack_sf_dict(forecast_dict)
if num_samples > 1:
samples = create_normal_samples(mu, std, num_samples, n)
else:
samples = mu

return self._build_forecast_series(samples)

@property
def supports_multivariate(self) -> bool:
return False

@property
def min_train_series_length(self) -> int:
return 10

@property
def _supports_range_index(self) -> bool:
return True

@property
def supports_probabilistic_prediction(self) -> bool:
return True
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
StatsForecastAutoARIMA,
StatsForecastAutoCES,
StatsForecastAutoETS,
StatsForecastAutoTBATS,
StatsForecastAutoTheta,
Theta,
)
Expand All @@ -57,6 +58,7 @@
(StatsForecastAutoTheta(season_length=12), 5.5),
(StatsForecastAutoCES(season_length=12, model="Z"), 7.3),
(StatsForecastAutoETS(season_length=12, model="AAZ"), 7.3),
(StatsForecastAutoTBATS(season_length=12), 10),
(Croston(version="classic"), 23),
(Croston(version="tsb", alpha_d=0.1, alpha_p=0.1), 23),
(Theta(), 11),
Expand Down
1 change: 1 addition & 0 deletions docs/userguide/covariates.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ GFMs are models that can be trained on multiple target (and covariate) time seri
| [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | || |
| [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | | | |
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | | | |
| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | | | |
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | | | |
| [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | | | |
| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) | || |
Expand Down

0 comments on commit aad1440

Please sign in to comment.