Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/optimized hfc torch #2013

Merged
merged 20 commits into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**
- Improvements to `TorchForecastingModel`:
- 🚀🚀 Optimized `historical_forecasts()` for pre-trained `TorchForecastingModel`, running up to 20 times faster than before! [#2013](https://github.com/unit8co/darts/pull/2013) by [Dennis Bader](https://github.com/dennisbader).
- Added callback `darts.utils.callbacks.TFMProgressBar` to customize at which model stages to display the progress bar. [#2020](https://github.com/unit8co/darts/pull/2020) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to documentation:
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).

**Fixed**
- Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).
Expand Down
7 changes: 7 additions & 0 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,13 @@ def supports_future_covariates(self) -> bool:
[model.supports_future_covariates for model in self.forecasting_models]
)

@property
def supports_optimized_historical_forecasts(self) -> bool:
"""
Whether the model supports optimized historical forecasts
"""
return False

@property
def _supports_non_retrainable_historical_forecasts(self) -> bool:
return self.is_global_ensemble
Expand Down
57 changes: 26 additions & 31 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def _build_forecast_series(
custom_components: Union[List[str], None] = None,
with_static_covs: bool = True,
with_hierarchy: bool = True,
pred_start: Optional[Union[pd.Timestamp, int]] = None,
) -> TimeSeries:
"""
Builds a forecast time series starting after the end of the training time series, with the
Expand All @@ -504,6 +505,9 @@ def _build_forecast_series(
If set to False, do not copy the input_series `static_covariates` attribute
with_hierarchy
If set to False, do not copy the input_series `hierarchy` attribute
pred_start
Optionally, give a custom prediction start point.

Returns
-------
TimeSeries
Expand All @@ -518,6 +522,7 @@ def _build_forecast_series(
custom_components,
with_static_covs,
with_hierarchy,
pred_start,
)

def _historical_forecasts_sanity_checks(self, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -546,40 +551,18 @@ def _get_last_prediction_time(
overlap_end,
latest_possible_prediction_start,
):
# when overlap_end=True, we can simply use the precomputed last possible prediction start point
# if `overlap_end` is True, we can use the pre-computed latest possible first prediction point
if overlap_end:
return latest_possible_prediction_start

# (1) otherwise, we have to step `forecast_horizon` steps back.
# (2) additionally, we check whether the `latest_possible_prediction_start` was shifted back
# from the overall theoretical latest possible prediction start point (which is by definition
# the first time step after the end of the target series) due to too short covariates.
theoretical_latest_prediction_start = series.end_time() + series.freq
if latest_possible_prediction_start == theoretical_latest_prediction_start:
# (1)
last_valid_pred_time = series.time_index[-forecast_horizon]
else:
# (2)
covariates_shift = (
len(
generate_index(
start=latest_possible_prediction_start,
end=theoretical_latest_prediction_start,
freq=series.freq,
)
)
- 2
)
last_valid_pred_time = series.time_index[
-(forecast_horizon + covariates_shift)
]
return last_valid_pred_time
# otherwise, the upper bound for the last time step of the last prediction is the end of the target series
return series.time_index[-forecast_horizon]

def _check_optimizable_historical_forecasts(
self,
forecast_horizon: int,
retrain: Union[bool, int, Callable[..., bool]],
show_warnings=bool,
show_warnings: bool,
) -> bool:
"""By default, historical forecasts cannot be optimized"""
return False
Expand Down Expand Up @@ -863,7 +846,13 @@ def retrain_func(
# predictable time indexes (assuming model is already trained)
historical_forecasts_time_index_predict = (
_get_historical_forecast_predict_index(
model, series_, idx, past_covariates_, future_covariates_
model,
series_,
idx,
past_covariates_,
future_covariates_,
forecast_horizon,
overlap_end,
)
)

Expand All @@ -876,6 +865,8 @@ def retrain_func(
idx,
past_covariates_,
future_covariates_,
forecast_horizon,
overlap_end,
)
)

Expand Down Expand Up @@ -910,12 +901,9 @@ def retrain_func(

# based on `forecast_horizon` and `overlap_end`, historical_forecasts_time_index is shortened
historical_forecasts_time_index = _adjust_historical_forecasts_time_index(
model=model,
series=series_,
series_idx=idx,
historical_forecasts_time_index=historical_forecasts_time_index,
forecast_horizon=forecast_horizon,
overlap_end=overlap_end,
start=start,
start_format=start_format,
show_warnings=show_warnings,
Expand Down Expand Up @@ -1041,7 +1029,7 @@ def retrain_func(
else:
forecasts.append(forecast)

if last_points_only:
if last_points_only and last_points_values:
forecasts_list.append(
TimeSeries.from_times_and_values(
generate_index(
Expand Down Expand Up @@ -2265,6 +2253,13 @@ def _supports_non_retrainable_historical_forecasts(self) -> bool:
"""GlobalForecastingModel supports historical forecasts without retraining the model"""
return True

@property
def supports_optimized_historical_forecasts(self) -> bool:
"""
Whether the model supports optimized historical forecasts
"""
return True

def _sanity_check_predict_likelihood_parameters(
self, n: int, output_chunk_length: Union[int, None], num_samples: int
):
Expand Down
12 changes: 10 additions & 2 deletions darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __init__(
self.pred_roll_size: Optional[int] = None
self.pred_batch_size: Optional[int] = None
self.pred_n_jobs: Optional[int] = None
self.predict_likelihood_parameters: Optional[bool] = None

@property
def first_prediction_index(self) -> int:
Expand Down Expand Up @@ -241,7 +242,11 @@ def predict_step(
dataloader_idx
the dataloader index
"""
input_data_tuple, batch_input_series = batch[:-1], batch[-1]
input_data_tuple, batch_input_series, batch_pred_starts = (
batch[:-2],
batch[-2],
batch[-1],
)

# number of individual series to be predicted in current batch
num_series = input_data_tuple[0].shape[0]
Expand Down Expand Up @@ -303,8 +308,11 @@ def predict_step(
else None,
with_static_covs=False if self.predict_likelihood_parameters else True,
with_hierarchy=False if self.predict_likelihood_parameters else True,
pred_start=pred_start,
)
for batch_idx, (input_series, pred_start) in enumerate(
zip(batch_input_series, batch_pred_starts)
)
for batch_idx, input_series in enumerate(batch_input_series)
)
return ts_forecasts

Expand Down
87 changes: 28 additions & 59 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@
create_lagged_training_data,
)
from darts.utils.historical_forecasts import (
_optimized_historical_forecasts_regression_all_points,
_optimized_historical_forecasts_regression_last_points_only,
_check_optimizable_historical_forecasts_global_models,
_optimized_historical_forecasts_all_points,
_optimized_historical_forecasts_last_points_only,
_process_historical_forecast_input,
)
from darts.utils.multioutput import MultiOutputRegressor
from darts.utils.utils import (
Expand Down Expand Up @@ -1068,40 +1070,23 @@ def supports_future_covariates(self) -> bool:
def supports_static_covariates(self) -> bool:
return True

@property
def supports_optimized_historical_forecasts(self) -> bool:
return True

def _check_optimizable_historical_forecasts(
self,
forecast_horizon: int,
retrain: Union[bool, int, Callable[..., bool]],
show_warnings=bool,
show_warnings: bool,
) -> bool:
"""
Historical forecast can be optimized only if `retrain=False` and `forecast_horizon <= self.output_chunk_length`
Historical forecast can be optimized only if `retrain=False` and `forecast_horizon <= model.output_chunk_length`
(no auto-regression required).
"""

supported_retrain = (retrain is False) or (retrain == 0)
supported_forecast_horizon = forecast_horizon <= self.output_chunk_length
if supported_retrain and supported_forecast_horizon:
return True

if show_warnings:
if not supported_retrain:
logger.warning(
"`enable_optimization=True` is ignored because `retrain` is not `False`"
"To hide this warning, set `show_warnings=False` or `enable_optimization=False`."
)
if not supported_forecast_horizon:
logger.warning(
"`enable_optimization=True` is ignored because "
"`forecast_horizon > self.output_chunk_length`."
"To hide this warning, set `show_warnings=False` or `enable_optimization=False`."
)

return False
return _check_optimizable_historical_forecasts_global_models(
model=self,
forecast_horizon=forecast_horizon,
retrain=retrain,
show_warnings=show_warnings,
allow_autoregression=False,
)

def _optimized_historical_forecasts(
self,
Expand All @@ -1122,41 +1107,25 @@ def _optimized_historical_forecasts(
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
"""
For RegressionModels we create the lagged prediction data once per series using a moving window.
With this, we can avoid having to recreate the tabular input data and call `model.predict()` for each
forecastable index and series.
Additionally, there are dedicated subroutines for `last_points_only=True` and `last_points_only=False`.

TODO: support forecast_horizon > output_chunk_length (auto-regression)
"""
if not self._fit_called:
raise_log(
ValueError("Model has not been fit yet."),
logger,
)
if forecast_horizon > self.output_chunk_length:
raise_log(
ValueError(
"`forecast_horizon > model.output_chunk_length` requires auto-regression which is not "
"supported in this optimized routine."
),
logger,
)

# manage covariates, usually handled by RegressionModel.predict()
if past_covariates is None and self.past_covariate_series is not None:
past_covariates = [self.past_covariate_series] * len(series)
if future_covariates is None and self.future_covariate_series is not None:
future_covariates = [self.future_covariate_series] * len(series)

self._verify_static_covariates(series[0].static_covariates)

if self.encoders.encoding_available:
past_covariates, future_covariates = self.generate_fit_predict_encodings(
n=forecast_horizon,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
)
series, past_covariates, future_covariates = _process_historical_forecast_input(
model=self,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
forecast_horizon=forecast_horizon,
allow_autoregression=False,
)

# TODO: move the loop here instead of duplicated code in each sub-routine?
if last_points_only:
return _optimized_historical_forecasts_regression_last_points_only(
return _optimized_historical_forecasts_last_points_only(
model=self,
series=series,
past_covariates=past_covariates,
Expand All @@ -1171,7 +1140,7 @@ def _optimized_historical_forecasts(
predict_likelihood_parameters=predict_likelihood_parameters,
)
else:
return _optimized_historical_forecasts_regression_all_points(
return _optimized_historical_forecasts_all_points(
model=self,
series=series,
past_covariates=past_covariates,
Expand Down
15 changes: 1 addition & 14 deletions darts/models/forecasting/theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import math
from typing import List, Optional, Tuple
from typing import List, Optional

import numpy as np
import statsmodels.tsa.holtwinters as hw
Expand Down Expand Up @@ -200,19 +200,6 @@ def min_train_series_length(self) -> int:
else:
return 3

@property
def extreme_lags(
self,
) -> Tuple[
Optional[int],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
]:
return -self.min_train_series_length, 0, None, None, None, None


class FourTheta(LocalForecastingModel):
def __init__(
Expand Down
Loading