From d9095894c895343fd86b54a1c667051671727083 Mon Sep 17 00:00:00 2001 From: madtoinou <32447896+madtoinou@users.noreply.github.com> Date: Mon, 11 Nov 2024 12:47:09 +0200 Subject: [PATCH] Fix/hfc opti reg prob (#2588) * fix: check that model is probabilistic when num samples is greater than 1 for optimized historical forecasts * feat: update the tests accordingly * update changelog * fix: simplify the test * fix: remove typo * fix: ignoring a linting commit for git blame --- .git-blame-ignore-revs | 2 + CHANGELOG.md | 1 + .../forecasting/test_regression_models.py | 42 +++++++++++++++++++ darts/utils/historical_forecasts/utils.py | 7 ++++ 4 files changed, 52 insertions(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 996e88a867..b9998dd0d2 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -6,3 +6,5 @@ 38cc6712a6f701703074a7a7c82ce0252fe869ee # Fix last isort issues and update Black to 22.1.0 8158d3eaef9d9f6e04f219b029e306d1f1be46d5 +# Change Python target-version to 3.9 and update Ruff to 0.7.2 +18e2e3fd7d82d239ab24807fcc1033094ea09940 diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ce70ae75a..66f6f90538 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Fixed** - Fixed a bug when using `darts.utils.data.tabularization.create_lagged_component_names()` with target `lags=None`, that did not return any lagged target label component names. [#2576](https://github.com/unit8co/darts/pull/2576) by [Dennis Bader](https://github.com/dennisbader). +- Fixed a bug when using `num_samples > 1` with a deterministic regression model and the optimized `historical_forecasts()` method, an exception was not raised. [#2576](https://github.com/unit8co/darts/pull/2588) by [Antoine Madrona](https://github.com/madtoinou). **Dependencies** diff --git a/darts/tests/models/forecasting/test_regression_models.py b/darts/tests/models/forecasting/test_regression_models.py index f865c94220..d6d1b6db11 100644 --- a/darts/tests/models/forecasting/test_regression_models.py +++ b/darts/tests/models/forecasting/test_regression_models.py @@ -1289,6 +1289,48 @@ def test_historical_forecast(self, mode): ) assert len(result) == 21 + def test_opti_historical_forecast_predict_checks(self): + """ + Verify that the sanity check implemented in ForecastingModel.predict are also defined for optimized historical + forecasts as it does not call this method + """ + model = self.models[1](lags=5) + + msg_expected = ( + "The model has not been fitted yet, and `retrain` is ``False``. Either call `fit()` before " + "`historical_forecasts()`, or set `retrain` to something different than ``False``." + ) + # untrained model, optimized + with pytest.raises(ValueError) as err: + model.historical_forecasts( + series=self.sine_univariate1, + start=0.9, + forecast_horizon=1, + retrain=False, + enable_optimization=True, + verbose=False, + ) + assert str(err.value) == msg_expected + + model.fit( + series=self.sine_univariate1, + ) + # deterministic model, num_samples > 1, optimized + with pytest.raises(ValueError) as err: + model.historical_forecasts( + series=self.sine_univariate1, + start=0.9, + forecast_horizon=1, + retrain=False, + enable_optimization=True, + num_samples=10, + verbose=False, + ) + assert ( + str(err.value) + == "`num_samples > 1` is only supported for probabilistic models." + ) + @pytest.mark.parametrize( "config", [ diff --git a/darts/utils/historical_forecasts/utils.py b/darts/utils/historical_forecasts/utils.py index 7481022517..cca6af103e 100644 --- a/darts/utils/historical_forecasts/utils.py +++ b/darts/utils/historical_forecasts/utils.py @@ -155,6 +155,13 @@ def _historical_forecasts_general_checks(model, series, kwargs): logger, ) + # duplication of ForecastingModel.predict() check for the optimized historical forecasts implementations + if not model.supports_probabilistic_prediction and n.num_samples > 1: + raise_log( + ValueError("`num_samples > 1` is only supported for probabilistic models."), + logger, + ) + # check direct likelihood parameter prediction before fitting a model if n.predict_likelihood_parameters: if not model.supports_likelihood_parameter_prediction: