Skip to content

Commit

Permalink
Feat/optimized hfc torch (#2013)
Browse files Browse the repository at this point in the history
* move hfc optimization checks to GlobalForecastingModel

* setup optimization files for hfc with torch models

* adapt torch infrerence datasets to work with stride and bounds

* first working version

* adapt for overlap_end=True

* fix test

* make tests for integer indexed series

* make multiple ts work

* update documentation

* fix issue with regression model optim hfc

* fix basic sample comparison

* allow autoregression in optim hfc for torch models

* remove some unnecessary lines

* update changelog

* refactor hist fc forecastable index

* add unit test for exact end

* update docs

* apply suggestions from PR review
  • Loading branch information
dennisbader authored Oct 28, 2023
1 parent c777193 commit ea37dc9
Show file tree
Hide file tree
Showing 16 changed files with 1,055 additions and 290 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**
- Improvements to `TorchForecastingModel`:
- 🚀🚀 Optimized `historical_forecasts()` for pre-trained `TorchForecastingModel` running up to 20 times faster than before!. [#2013](https://github.com/unit8co/darts/pull/2013) by [Dennis Bader](https://github.com/dennisbader).
- Added callback `darts.utils.callbacks.TFMProgressBar` to customize at which model stages to display the progress bar. [#2020](https://github.com/unit8co/darts/pull/2020) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to documentation:
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).

**Fixed**
- Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).
Expand Down
7 changes: 7 additions & 0 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,13 @@ def supports_future_covariates(self) -> bool:
[model.supports_future_covariates for model in self.forecasting_models]
)

@property
def supports_optimized_historical_forecasts(self) -> bool:
"""
Whether the model supports optimized historical forecasts
"""
return False

@property
def _supports_non_retrainable_historical_forecasts(self) -> bool:
return self.is_global_ensemble
Expand Down
57 changes: 26 additions & 31 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def _build_forecast_series(
custom_components: Union[List[str], None] = None,
with_static_covs: bool = True,
with_hierarchy: bool = True,
pred_start: Optional[Union[pd.Timestamp, int]] = None,
) -> TimeSeries:
"""
Builds a forecast time series starting after the end of the training time series, with the
Expand All @@ -504,6 +505,9 @@ def _build_forecast_series(
If set to False, do not copy the input_series `static_covariates` attribute
with_hierarchy
If set to False, do not copy the input_series `hierarchy` attribute
pred_start
Optionally, give a custom prediction start point.
Returns
-------
TimeSeries
Expand All @@ -518,6 +522,7 @@ def _build_forecast_series(
custom_components,
with_static_covs,
with_hierarchy,
pred_start,
)

def _historical_forecasts_sanity_checks(self, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -546,40 +551,18 @@ def _get_last_prediction_time(
overlap_end,
latest_possible_prediction_start,
):
# when overlap_end=True, we can simply use the precomputed last possible prediction start point
# if `overlap_end` is True, we can use the pre-computed latest possible first prediction point
if overlap_end:
return latest_possible_prediction_start

# (1) otherwise, we have to step `forecast_horizon` steps back.
# (2) additionally, we check whether the `latest_possible_prediction_start` was shifted back
# from the overall theoretical latest possible prediction start point (which is by definition
# the first time step after the end of the target series) due to too short covariates.
theoretical_latest_prediction_start = series.end_time() + series.freq
if latest_possible_prediction_start == theoretical_latest_prediction_start:
# (1)
last_valid_pred_time = series.time_index[-forecast_horizon]
else:
# (2)
covariates_shift = (
len(
generate_index(
start=latest_possible_prediction_start,
end=theoretical_latest_prediction_start,
freq=series.freq,
)
)
- 2
)
last_valid_pred_time = series.time_index[
-(forecast_horizon + covariates_shift)
]
return last_valid_pred_time
# otherwise, the upper bound for the last time step of the last prediction is the end of the target series
return series.time_index[-forecast_horizon]

def _check_optimizable_historical_forecasts(
self,
forecast_horizon: int,
retrain: Union[bool, int, Callable[..., bool]],
show_warnings=bool,
show_warnings: bool,
) -> bool:
"""By default, historical forecasts cannot be optimized"""
return False
Expand Down Expand Up @@ -863,7 +846,13 @@ def retrain_func(
# predictable time indexes (assuming model is already trained)
historical_forecasts_time_index_predict = (
_get_historical_forecast_predict_index(
model, series_, idx, past_covariates_, future_covariates_
model,
series_,
idx,
past_covariates_,
future_covariates_,
forecast_horizon,
overlap_end,
)
)

Expand All @@ -876,6 +865,8 @@ def retrain_func(
idx,
past_covariates_,
future_covariates_,
forecast_horizon,
overlap_end,
)
)

Expand Down Expand Up @@ -910,12 +901,9 @@ def retrain_func(

# based on `forecast_horizon` and `overlap_end`, historical_forecasts_time_index is shortened
historical_forecasts_time_index = _adjust_historical_forecasts_time_index(
model=model,
series=series_,
series_idx=idx,
historical_forecasts_time_index=historical_forecasts_time_index,
forecast_horizon=forecast_horizon,
overlap_end=overlap_end,
start=start,
start_format=start_format,
show_warnings=show_warnings,
Expand Down Expand Up @@ -1041,7 +1029,7 @@ def retrain_func(
else:
forecasts.append(forecast)

if last_points_only:
if last_points_only and last_points_values:
forecasts_list.append(
TimeSeries.from_times_and_values(
generate_index(
Expand Down Expand Up @@ -2265,6 +2253,13 @@ def _supports_non_retrainable_historical_forecasts(self) -> bool:
"""GlobalForecastingModel supports historical forecasts without retraining the model"""
return True

@property
def supports_optimized_historical_forecasts(self) -> bool:
"""
Whether the model supports optimized historical forecasts
"""
return True

def _sanity_check_predict_likelihood_parameters(
self, n: int, output_chunk_length: Union[int, None], num_samples: int
):
Expand Down
12 changes: 10 additions & 2 deletions darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __init__(
self.pred_roll_size: Optional[int] = None
self.pred_batch_size: Optional[int] = None
self.pred_n_jobs: Optional[int] = None
self.predict_likelihood_parameters: Optional[bool] = None

@property
def first_prediction_index(self) -> int:
Expand Down Expand Up @@ -241,7 +242,11 @@ def predict_step(
dataloader_idx
the dataloader index
"""
input_data_tuple, batch_input_series = batch[:-1], batch[-1]
input_data_tuple, batch_input_series, batch_pred_starts = (
batch[:-2],
batch[-2],
batch[-1],
)

# number of individual series to be predicted in current batch
num_series = input_data_tuple[0].shape[0]
Expand Down Expand Up @@ -303,8 +308,11 @@ def predict_step(
else None,
with_static_covs=False if self.predict_likelihood_parameters else True,
with_hierarchy=False if self.predict_likelihood_parameters else True,
pred_start=pred_start,
)
for batch_idx, (input_series, pred_start) in enumerate(
zip(batch_input_series, batch_pred_starts)
)
for batch_idx, input_series in enumerate(batch_input_series)
)
return ts_forecasts

Expand Down
87 changes: 28 additions & 59 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@
create_lagged_training_data,
)
from darts.utils.historical_forecasts import (
_optimized_historical_forecasts_regression_all_points,
_optimized_historical_forecasts_regression_last_points_only,
_check_optimizable_historical_forecasts_global_models,
_optimized_historical_forecasts_all_points,
_optimized_historical_forecasts_last_points_only,
_process_historical_forecast_input,
)
from darts.utils.multioutput import MultiOutputRegressor
from darts.utils.utils import (
Expand Down Expand Up @@ -1068,40 +1070,23 @@ def supports_future_covariates(self) -> bool:
def supports_static_covariates(self) -> bool:
return True

@property
def supports_optimized_historical_forecasts(self) -> bool:
return True

def _check_optimizable_historical_forecasts(
self,
forecast_horizon: int,
retrain: Union[bool, int, Callable[..., bool]],
show_warnings=bool,
show_warnings: bool,
) -> bool:
"""
Historical forecast can be optimized only if `retrain=False` and `forecast_horizon <= self.output_chunk_length`
Historical forecast can be optimized only if `retrain=False` and `forecast_horizon <= model.output_chunk_length`
(no auto-regression required).
"""

supported_retrain = (retrain is False) or (retrain == 0)
supported_forecast_horizon = forecast_horizon <= self.output_chunk_length
if supported_retrain and supported_forecast_horizon:
return True

if show_warnings:
if not supported_retrain:
logger.warning(
"`enable_optimization=True` is ignored because `retrain` is not `False`"
"To hide this warning, set `show_warnings=False` or `enable_optimization=False`."
)
if not supported_forecast_horizon:
logger.warning(
"`enable_optimization=True` is ignored because "
"`forecast_horizon > self.output_chunk_length`."
"To hide this warning, set `show_warnings=False` or `enable_optimization=False`."
)

return False
return _check_optimizable_historical_forecasts_global_models(
model=self,
forecast_horizon=forecast_horizon,
retrain=retrain,
show_warnings=show_warnings,
allow_autoregression=False,
)

def _optimized_historical_forecasts(
self,
Expand All @@ -1122,41 +1107,25 @@ def _optimized_historical_forecasts(
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
"""
For RegressionModels we create the lagged prediction data once per series using a moving window.
With this, we can avoid having to recreate the tabular input data and call `model.predict()` for each
forecastable index and series.
Additionally, there is a dedicated subroutines for `last_points_only=True` and `last_points_only=False`.
TODO: support forecast_horizon > output_chunk_length (auto-regression)
"""
if not self._fit_called:
raise_log(
ValueError("Model has not been fit yet."),
logger,
)
if forecast_horizon > self.output_chunk_length:
raise_log(
ValueError(
"`forecast_horizon > model.output_chunk_length` requires auto-regression which is not "
"supported in this optimized routine."
),
logger,
)

# manage covariates, usually handled by RegressionModel.predict()
if past_covariates is None and self.past_covariate_series is not None:
past_covariates = [self.past_covariate_series] * len(series)
if future_covariates is None and self.future_covariate_series is not None:
future_covariates = [self.future_covariate_series] * len(series)

self._verify_static_covariates(series[0].static_covariates)

if self.encoders.encoding_available:
past_covariates, future_covariates = self.generate_fit_predict_encodings(
n=forecast_horizon,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
)
series, past_covariates, future_covariates = _process_historical_forecast_input(
model=self,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
forecast_horizon=forecast_horizon,
allow_autoregression=False,
)

# TODO: move the loop here instead of duplicated code in each sub-routine?
if last_points_only:
return _optimized_historical_forecasts_regression_last_points_only(
return _optimized_historical_forecasts_last_points_only(
model=self,
series=series,
past_covariates=past_covariates,
Expand All @@ -1171,7 +1140,7 @@ def _optimized_historical_forecasts(
predict_likelihood_parameters=predict_likelihood_parameters,
)
else:
return _optimized_historical_forecasts_regression_all_points(
return _optimized_historical_forecasts_all_points(
model=self,
series=series,
past_covariates=past_covariates,
Expand Down
15 changes: 1 addition & 14 deletions darts/models/forecasting/theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import math
from typing import List, Optional, Tuple
from typing import List, Optional

import numpy as np
import statsmodels.tsa.holtwinters as hw
Expand Down Expand Up @@ -200,19 +200,6 @@ def min_train_series_length(self) -> int:
else:
return 3

@property
def extreme_lags(
self,
) -> Tuple[
Optional[int],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
]:
return -self.min_train_series_length, 0, None, None, None, None


class FourTheta(LocalForecastingModel):
def __init__(
Expand Down
Loading

0 comments on commit ea37dc9

Please sign in to comment.