diff --git a/CHANGELOG.md b/CHANGELOG.md index 3bda881dc1..cf805cad84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** +- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl). + **Fixed** **Dependencies** diff --git a/README.md b/README.md index e73723d3b0..46b04eb0f1 100644 --- a/README.md +++ b/README.md @@ -237,6 +237,7 @@ on bringing more models and features. | [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 ✅ 🔴 | ✅ 🔴 | 🔴 | | [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 | | [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | [TBATS paper](https://robjhyndman.com/papers/ComplexSeasonality.pdf) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 | +| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 | | [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | [Theta](https://robjhyndman.com/papers/Theta.pdf) & [4 Theta](https://github.com/Mcompetitions/M4-methods/blob/master/4Theta%20method.R) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 | | [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 | | [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) | [Prophet repo](https://github.com/facebook/prophet) | ✅ 🔴 | 🔴 ✅ 🔴 | ✅ 🔴 | 🔴 | diff --git a/darts/models/__init__.py b/darts/models/__init__.py index 1ea802be3a..bfbe716b54 100644 --- a/darts/models/__init__.py +++ b/darts/models/__init__.py @@ -94,6 +94,7 @@ from darts.models.forecasting.sf_auto_arima import StatsForecastAutoARIMA from darts.models.forecasting.sf_auto_ces import StatsForecastAutoCES from darts.models.forecasting.sf_auto_ets import StatsForecastAutoETS + from darts.models.forecasting.sf_auto_tbats import StatsForecastAutoTBATS from darts.models.forecasting.sf_auto_theta import StatsForecastAutoTheta except ImportError: @@ -108,6 +109,7 @@ StatsForecastAutoCES = NotImportedModule(module_name="StatsForecast", warn=False) StatsForecastAutoETS = NotImportedModule(module_name="StatsForecast", warn=False) StatsForecastAutoTheta = NotImportedModule(module_name="StatsForecast", warn=False) + StatsForecastAutoTBATS = NotImportedModule(module_name="StatsForecast", warn=False) try: from darts.models.forecasting.xgboost import XGBModel @@ -160,6 +162,7 @@ "StatsForecastAutoCES", "StatsForecastAutoETS", "StatsForecastAutoTheta", + "StatsForecastAutoTBATS", "XGBModel", "GaussianProcessFilter", "KalmanFilter", diff --git a/darts/models/forecasting/__init__.py b/darts/models/forecasting/__init__.py index b3559f9b62..85ad3d8730 100644 --- a/darts/models/forecasting/__init__.py +++ b/darts/models/forecasting/__init__.py @@ -21,6 +21,7 @@ - :class:`~darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES` - :class:`~darts.models.forecasting.tbats_model.BATS` - :class:`~darts.models.forecasting.tbats_model.TBATS` + - :class:`~darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS` - :class:`~darts.models.forecasting.theta.Theta` - :class:`~darts.models.forecasting.theta.FourTheta` - :class:`~darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta` diff --git a/darts/models/forecasting/sf_auto_tbats.py b/darts/models/forecasting/sf_auto_tbats.py new file mode 100644 index 0000000000..7e1bc16746 --- /dev/null +++ b/darts/models/forecasting/sf_auto_tbats.py @@ -0,0 +1,104 @@ +""" +StatsForecastAutoTBATS +----------- +""" + +from statsforecast.models import AutoTBATS as SFAutoTBATS + +from darts import TimeSeries +from darts.models.components.statsforecast_utils import ( + create_normal_samples, + one_sigma_rule, + unpack_sf_dict, +) +from darts.models.forecasting.forecasting_model import LocalForecastingModel + + +class StatsForecastAutoTBATS(LocalForecastingModel): + def __init__(self, *autoTBATS_args, **autoTBATS_kwargs): + """Auto-TBATS based on `Statsforecasts package + `_. + + Automatically selects the best TBATS model from all feasible combinations of the parameters `use_boxcox`, + `use_trend`, `use_damped_trend`, and `use_arma_errors`. Selection is made using the AIC. + Default value for `use_arma_errors` is True since this enables the evaluation of models with + and without ARMA errors. + + + + We refer to the `statsforecast AutoTBATS documentation + `_ + for the exhaustive documentation of the arguments. + + Parameters + ---------- + autoTBATS_args + Positional arguments for ``statsforecasts.models.AutoTBATS``. + autoTBATS_kwargs + Keyword arguments for ``statsforecasts.models.AutoTBATS``. + + Examples + -------- + >>> from darts.datasets import AirPassengersDataset + >>> from darts.models import StatsForecastAutoTBATS + >>> series = AirPassengersDataset().load() + >>> # define StatsForecastAutoTBATS parameters + >>> model = StatsForecastAutoTBATS(season_length=12) + >>> model.fit(series) + >>> pred = model.predict(6) + >>> pred.values() + array([[450.79653684], + [472.09265790], + [497.76948306], + [510.74927369], + [520.92224557], + [570.33881522]]) + """ + super().__init__() + self.model = SFAutoTBATS(*autoTBATS_args, **autoTBATS_kwargs) + + def fit(self, series: TimeSeries): + super().fit(series) + self._assert_univariate(series) + series = self.training_series + self.model.fit( + series.values(copy=False).flatten(), + ) + return self + + def predict( + self, + n: int, + num_samples: int = 1, + verbose: bool = False, + show_warnings: bool = True, + ): + super().predict(n, num_samples) + forecast_dict = self.model.predict( + h=n, + level=(one_sigma_rule,), # ask one std for the confidence interval. + ) + + mu, std = unpack_sf_dict(forecast_dict) + if num_samples > 1: + samples = create_normal_samples(mu, std, num_samples, n) + else: + samples = mu + + return self._build_forecast_series(samples) + + @property + def supports_multivariate(self) -> bool: + return False + + @property + def min_train_series_length(self) -> int: + return 10 + + @property + def _supports_range_index(self) -> bool: + return True + + @property + def supports_probabilistic_prediction(self) -> bool: + return True diff --git a/darts/tests/models/forecasting/test_local_forecasting_models.py b/darts/tests/models/forecasting/test_local_forecasting_models.py index e1e7361a60..186c93a813 100644 --- a/darts/tests/models/forecasting/test_local_forecasting_models.py +++ b/darts/tests/models/forecasting/test_local_forecasting_models.py @@ -35,6 +35,7 @@ StatsForecastAutoARIMA, StatsForecastAutoCES, StatsForecastAutoETS, + StatsForecastAutoTBATS, StatsForecastAutoTheta, Theta, ) @@ -57,6 +58,7 @@ (StatsForecastAutoTheta(season_length=12), 5.5), (StatsForecastAutoCES(season_length=12, model="Z"), 7.3), (StatsForecastAutoETS(season_length=12, model="AAZ"), 7.3), + (StatsForecastAutoTBATS(season_length=12), 10), (Croston(version="classic"), 23), (Croston(version="tsb", alpha_d=0.1, alpha_p=0.1), 23), (Theta(), 11), diff --git a/docs/userguide/covariates.md b/docs/userguide/covariates.md index 8df7dc94eb..c393594360 100644 --- a/docs/userguide/covariates.md +++ b/docs/userguide/covariates.md @@ -133,6 +133,7 @@ GFMs are models that can be trained on multiple target (and covariate) time seri | [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | | ✅ | | | [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | | | | | [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | | | | +| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | | | | | [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | | | | | [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | | | | | [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) | | ✅ | |