diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py
index abbc17d514..6321689288 100644
--- a/darts/models/forecasting/forecasting_model.py
+++ b/darts/models/forecasting/forecasting_model.py
@@ -42,6 +42,7 @@
     _get_historical_forecast_predict_index,
     _get_historical_forecast_train_index,
     _historical_forecasts_general_checks,
+    _historical_forecasts_sanitize_kwargs,
     _reconciliate_historical_time_indices,
 )
 from darts.utils.timeseries_generation import (
@@ -816,10 +817,14 @@ def retrain_func(
                 logger,
             )

-        if fit_kwargs is None:
-            fit_kwargs = dict()
-        if predict_kwargs is None:
-            predict_kwargs = dict()
+        # remove unsupported arguments, raise an exception if they interfere with the historical forecasts logic
+        fit_kwargs, predict_kwargs = _historical_forecasts_sanitize_kwargs(
+            model=model,
+            fit_kwargs=fit_kwargs,
+            predict_kwargs=predict_kwargs,
+            retrain=retrain,
+            show_warnings=show_warnings,
+        )

         series = series2seq(series)
         past_covariates = series2seq(past_covariates)
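Note on the call site above: the previous inline `None`-to-`dict()` fallbacks are folded into a single helper that also validates and filters the kwargs before the retrain/predict loop runs. A minimal sketch of the intended contract, assuming only the signatures shown in this diff (the surrounding `historical_forecasts` arguments are in scope):

    # Both returned values are plain dicts, safe to splat into
    # model.fit(**fit_kwargs) and model.predict(**predict_kwargs) later on.
    fit_kwargs, predict_kwargs = _historical_forecasts_sanitize_kwargs(
        model=model,
        fit_kwargs=None,                        # None is normalized to {}
        predict_kwargs={"batch_size": 10},      # filtered against model.predict's signature
        retrain=True,
        show_warnings=True,
    )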
diff --git a/darts/tests/models/forecasting/test_historical_forecasts.py b/darts/tests/models/forecasting/test_historical_forecasts.py
index 93ff1e18d6..d1ce1f0fbf 100644
--- a/darts/tests/models/forecasting/test_historical_forecasts.py
+++ b/darts/tests/models/forecasting/test_historical_forecasts.py
@@ -333,6 +333,29 @@ class TestHistoricalforecast:
     # slightly longer to not affect the last predictable timestamp
     ts_covs = tg.gaussian_timeseries(length=30, start=start_ts)

+    @staticmethod
+    def create_model(ocl, use_ll=True, model_type="regression"):
+        if model_type == "regression":
+            return LinearRegressionModel(
+                lags=3,
+                likelihood="quantile" if use_ll else None,
+                quantiles=[0.05, 0.4, 0.5, 0.6, 0.95] if use_ll else None,
+                output_chunk_length=ocl,
+            )
+        else:  # model_type == "torch"
+            if not TORCH_AVAILABLE:
+                return None
+            return NLinearModel(
+                input_chunk_length=3,
+                likelihood=QuantileRegression([0.05, 0.4, 0.5, 0.6, 0.95])
+                if use_ll
+                else None,
+                output_chunk_length=ocl,
+                n_epochs=1,
+                random_state=42,
+                **tfm_kwargs,
+            )
+
     def test_historical_forecasts_transferrable_future_cov_local_models(self):
         model = ARIMA()
         assert model.min_train_series_length == 30
@@ -1827,29 +1850,7 @@ def test_predict_likelihood_parameters(self, model_type):
         """standard checks that historical forecasts work with direct likelihood parameter
         predictions with regression and torch models."""

-        def create_model(ocl, use_ll=True, model_type="regression"):
-            if model_type == "regression":
-                return LinearRegressionModel(
-                    lags=3,
-                    likelihood="quantile" if use_ll else None,
-                    quantiles=[0.05, 0.4, 0.5, 0.6, 0.95] if use_ll else None,
-                    output_chunk_length=ocl,
-                )
-            else:  # model_type == "torch"
-                if not TORCH_AVAILABLE:
-                    return None
-                return NLinearModel(
-                    input_chunk_length=3,
-                    likelihood=QuantileRegression([0.05, 0.4, 0.5, 0.6, 0.95])
-                    if use_ll
-                    else None,
-                    output_chunk_length=ocl,
-                    n_epochs=1,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-
-        model = create_model(1, False, model_type=model_type)
+        model = self.create_model(1, False, model_type=model_type)
         # skip torch models if not installed
         if model is None:
             return
@@ -1860,7 +1861,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):
             predict_likelihood_parameters=True,
         )

-        model = create_model(1, model_type=model_type)
+        model = self.create_model(1, model_type=model_type)
         # forecast_horizon > output_chunk_length doesn't work
         with pytest.raises(ValueError):
             model.historical_forecasts(
@@ -1869,7 +1870,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):
                 forecast_horizon=2,
             )

-        model = create_model(1, model_type=model_type)
+        model = self.create_model(1, model_type=model_type)
         # num_samples != 1 doesn't work
         with pytest.raises(ValueError):
             model.historical_forecasts(
@@ -1884,7 +1885,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):
         qs_expected = ["q0.05", "q0.40", "q0.50", "q0.60", "q0.95"]
         qs_expected = pd.Index([target_name + "_" + q for q in qs_expected])
         # check that it works with retrain
-        model = create_model(1, model_type=model_type)
+        model = self.create_model(1, model_type=model_type)
         hist_fc = model.historical_forecasts(
             self.ts_pass_train,
             predict_likelihood_parameters=True,
@@ -1897,7 +1898,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):
         assert len(hist_fc) == n

         # check for equal results between predict and hist fc without retraining
-        model = create_model(1, model_type=model_type)
+        model = self.create_model(1, model_type=model_type)
         model.fit(series=self.ts_pass_train[:-n])
         hist_fc = model.historical_forecasts(
             self.ts_pass_train,
@@ -1926,7 +1927,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):

         # check equal results between predict and hist fc with higher output_chunk_length and horizon,
         # and last_points_only=False
-        model = create_model(2, model_type=model_type)
+        model = self.create_model(2, model_type=model_type)
         # we take one more training step so that model trained on ocl=1 has the same training samples
         # as model above
         model.fit(series=self.ts_pass_train[: -(n - 1)])
@@ -1959,3 +1960,101 @@ def create_model(ocl, use_ll=True, model_type="regression"):
                 p.all_values(copy=False), hfc.all_values(copy=False)
             )
         assert len(hist_fc) == n + 1
+
+    @pytest.mark.parametrize("model_type", ["regression", "torch"])
+    def test_fit_kwargs(self, monkeypatch, model_type):
+        """check that the parameters provided in fit_kwargs are correctly processed"""
+        valid_fit_kwargs = {"max_samples_per_ts": 3}
+        invalid_fit_kwargs = {"series": self.ts_pass_train}
+        if model_type == "regression":
+            unsupported_fit_kwargs = {"trainer": None}
+        elif model_type == "torch":
+            unsupported_fit_kwargs = {"n_jobs_multioutput_wrapper": False}
+
+        n = 2
+        model = self.create_model(1, use_ll=False, model_type=model_type)
+        model.fit(series=self.ts_pass_train[:-n])
+
+        # supported argument
+        hist_fc = model.historical_forecasts(
+            self.ts_pass_train,
+            forecast_horizon=1,
+            num_samples=1,
+            start=len(self.ts_pass_train) - n,
+            retrain=True,
+            fit_kwargs=valid_fit_kwargs,
+        )
+
+        assert hist_fc.components.equals(self.ts_pass_train.components)
+        assert len(hist_fc) == n
+
+        # passing unsupported argument
+        hist_fc = model.historical_forecasts(
+            self.ts_pass_train,
+            forecast_horizon=1,
+            start=len(self.ts_pass_train) - n,
+            retrain=True,
+            fit_kwargs=unsupported_fit_kwargs,
+        )
+
+        assert hist_fc.components.equals(self.ts_pass_train.components)
+        assert len(hist_fc) == n
+
+        # passing hist_fc parameters in fit_kwargs, interfering with the logic
+        with pytest.raises(ValueError):
+            hist_fc = model.historical_forecasts(
+                self.ts_pass_train,
+                forecast_horizon=1,
+                start=len(self.ts_pass_train) - n,
+                retrain=True,
+                fit_kwargs=invalid_fit_kwargs,
+            )
+
+    @pytest.mark.parametrize("model_type", ["regression", "torch"])
+    def test_predict_kwargs(self, monkeypatch, model_type):
+        """check that the parameters provided in predict_kwargs are correctly processed"""
+        invalid_predict_kwargs = {"predict_likelihood_parameters": False}
+        if model_type == "regression":
+            valid_predict_kwargs = {}
+            unsupported_predict_kwargs = {"batch_size": 10}
+        elif model_type == "torch":
+            valid_predict_kwargs = {"batch_size": 10}
+            unsupported_predict_kwargs = {}
+
+        n = 2
+        model = self.create_model(1, use_ll=False, model_type=model_type)
+        model.fit(series=self.ts_pass_train[:-n])
+
+        # supported argument
+        hist_fc = model.historical_forecasts(
+            self.ts_pass_train,
+            forecast_horizon=1,
+            start=len(self.ts_pass_train) - n,
+            retrain=False,
+            predict_kwargs=valid_predict_kwargs,
+        )
+
+        assert hist_fc.components.equals(self.ts_pass_train.components)
+        assert len(hist_fc) == n
+
+        # passing unsupported argument
+        hist_fc = model.historical_forecasts(
+            self.ts_pass_train,
+            forecast_horizon=1,
+            start=len(self.ts_pass_train) - n,
+            retrain=False,
+            predict_kwargs=unsupported_predict_kwargs,
+        )
+
+        assert hist_fc.components.equals(self.ts_pass_train.components)
+        assert len(hist_fc) == n
+
+        # passing hist_fc parameters in predict_kwargs, interfering with the logic
+        with pytest.raises(ValueError):
+            hist_fc = model.historical_forecasts(
+                self.ts_pass_train,
+                forecast_horizon=1,
+                start=len(self.ts_pass_train) - n,
+                retrain=False,
+                predict_kwargs=invalid_predict_kwargs,
+            )
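The two new tests pin down three behaviours for each method: kwargs supported by `fit()`/`predict()` are forwarded, unsupported kwargs are dropped (with a warning when `show_warnings=True`), and kwargs that collide with `historical_forecasts`'s own parameters raise. As a usage sketch mirroring the test setup, assuming a fitted Darts `LinearRegressionModel` named `model` and a `TimeSeries` named `series`:

    # forwarded: max_samples_per_ts is a parameter of model.fit()
    model.historical_forecasts(series, forecast_horizon=1, retrain=True,
                               fit_kwargs={"max_samples_per_ts": 3})
    # dropped with a warning: trainer is not a parameter of a regression model's fit()
    model.historical_forecasts(series, forecast_horizon=1, retrain=True,
                               fit_kwargs={"trainer": None})
    # raises ValueError: series would interfere with historical_forecasts' own argument
    model.historical_forecasts(series, forecast_horizon=1, retrain=True,
                               fit_kwargs={"series": series})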
diff --git a/darts/utils/historical_forecasts/utils.py b/darts/utils/historical_forecasts/utils.py
index 5c299ad0d7..538a69bcf1 100644
--- a/darts/utils/historical_forecasts/utils.py
+++ b/darts/utils/historical_forecasts/utils.py
@@ -210,7 +210,16 @@ def _historical_forecasts_general_checks(model, series, kwargs):
             logger,
         )

-    if n.fit_kwargs is not None or n.predict_kwargs is not None:
+
+def _historical_forecasts_sanitize_kwargs(
+    model,
+    fit_kwargs: Optional[Dict[str, Any]],
+    predict_kwargs: Optional[Dict[str, Any]],
+    retrain: bool,
+    show_warnings: bool,
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """Convert kwargs to dictionaries and check that their content is compatible with the called methods."""
+    if fit_kwargs is not None or predict_kwargs is not None:
         hfc_args = set(inspect.signature(model.historical_forecasts).parameters)
         # replace `forecast_horizon` with `n`
         hfc_args = hfc_args - {"forecast_horizon"}
@@ -222,57 +231,73 @@ def _historical_forecasts_general_checks(model, series, kwargs):
             "val_future_covariates",
         }

-    if n.fit_kwargs is not None:
-        if n.retrain:
+    if fit_kwargs is None:
+        fit_kwargs = dict()
+    else:
+        if retrain:
             fit_args = set(inspect.signature(model.fit).parameters)
-            _historical_forecasts_kwargs_checks(
+            fit_kwargs = _historical_forecasts_check_kwargs(
                 hfc_args=hfc_args,
                 name_kwargs="fit_kwargs",
-                dict_kwargs=n.fit_kwargs,
+                dict_kwargs=fit_kwargs,
                 method_args=fit_args,
-                show_warnings=n.show_warnings,
+                show_warnings=show_warnings,
             )
-        elif n.show_warnings:
+        elif show_warnings:
             logger.warning(
                 "`fit_kwargs` was provided with `retrain=False`, the argument will be ignored."
             )

-    if n.predict_kwargs is not None:
+    if predict_kwargs is None:
+        predict_kwargs = dict()
+    else:
         predict_args = set(inspect.signature(model.predict).parameters)
-        _historical_forecasts_kwargs_checks(
+        predict_kwargs = _historical_forecasts_check_kwargs(
             hfc_args=hfc_args,
             name_kwargs="predict_kwargs",
-            dict_kwargs=n.predict_kwargs,
+            dict_kwargs=predict_kwargs,
             method_args=predict_args,
-            show_warnings=n.show_warnings,
+            show_warnings=show_warnings,
         )
+    return fit_kwargs, predict_kwargs

-def _historical_forecasts_kwargs_checks(
+
+def _historical_forecasts_check_kwargs(
     hfc_args: Set[str],
     name_kwargs: str,
     dict_kwargs: Dict[str, Any],
     method_args: Set[str],
     show_warnings: bool,
-):
+) -> Dict[str, Any]:
     """
-    Return a warning if some argument are not supported and an exception if some arguments interfere with
-    historical_forecasts logic
+    Return the kwargs dict without the arguments that are unsupported by the model method.
+
+    Raise a warning if some arguments are not supported and an exception if some arguments interfere with
+    the historical_forecasts logic.
     """
-    ignored_args = set(dict_kwargs) - method_args
-    if show_warnings and len(ignored_args) > 0:
-        logger.warning(
-            f"The following parameters in `{name_kwargs}` will be ignored was they are not supported by "
-            f"the model method : {ignored_args}."
-        )
     invalid_args = set(dict_kwargs).intersection(hfc_args)
     if len(invalid_args) > 0:
         raise_log(
-            f"The following parameters cannot be passed using `{name_kwargs}` as they would interfere with "
-            f"historical forecasts logic : {invalid_args}.",
+            ValueError(
+                f"The following parameters cannot be passed using `{name_kwargs}` as they would interfere with "
+                f"the historical forecasts logic: {invalid_args}."
+            ),
             logger,
         )
+    ignored_args = set(dict_kwargs) - method_args
+    if len(ignored_args) > 0:
+        # remove unsupported arguments to avoid an exception being thrown by Python
+        dict_kwargs = {k: v for k, v in dict_kwargs.items() if k not in ignored_args}
+        if show_warnings:
+            logger.warning(
+                f"The following parameters in `{name_kwargs}` will be ignored as they are not supported by "
+                f"the model method: {ignored_args}."
+            )
+
+    return dict_kwargs
+

 def _historical_forecasts_start_warnings(
     idx: int,
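The sanitization rests on comparing the user's kwargs against method signatures via `inspect.signature`. A self-contained sketch of the technique outside of Darts (illustrative names only, not the library's API):

    import inspect
    from typing import Any, Dict, Set

    def filter_kwargs(func, kwargs: Dict[str, Any], reserved: Set[str]) -> Dict[str, Any]:
        # Reject kwargs the caller reserves for itself, then drop kwargs `func` cannot accept.
        invalid = set(kwargs) & reserved
        if invalid:
            raise ValueError(f"reserved parameters cannot be overridden: {invalid}")
        accepted = set(inspect.signature(func).parameters)
        return {k: v for k, v in kwargs.items() if k in accepted}

    def fit(series, max_samples_per_ts=None):
        pass

    clean = filter_kwargs(fit, {"max_samples_per_ts": 3, "batch_size": 10}, reserved={"series"})
    assert clean == {"max_samples_per_ts": 3}  # "batch_size" was dropped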