feat: improve fit/predict_kwargs handling, add tests
madtoinou committed Nov 10, 2023
1 parent c4ef3fc commit 986a524
Showing 3 changed files with 184 additions and 55 deletions.
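For orientation before the diffs: the user-visible effect of this commit is that `historical_forecasts` now validates and filters `fit_kwargs`/`predict_kwargs` instead of passing them through blindly. A minimal usage sketch, assuming a darts version containing this commit (the series construction and kwargs mirror the new tests):

from darts.models import LinearRegressionModel
from darts.utils.timeseries_generation import gaussian_timeseries

series = gaussian_timeseries(length=30)
model = LinearRegressionModel(lags=3, output_chunk_length=1)

# supported fit argument: forwarded to model.fit() at each retraining
hist_fc = model.historical_forecasts(
    series,
    forecast_horizon=1,
    start=len(series) - 2,
    retrain=True,
    fit_kwargs={"max_samples_per_ts": 3},
)

# arguments that collide with the historical forecasts logic raise a ValueError
try:
    model.historical_forecasts(
        series,
        forecast_horizon=1,
        retrain=True,
        fit_kwargs={"series": series},
    )
except ValueError:
    pass  # `series` must be passed to historical_forecasts directly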
13 changes: 9 additions & 4 deletions darts/models/forecasting/forecasting_model.py
@@ -42,6 +42,7 @@
     _get_historical_forecast_predict_index,
     _get_historical_forecast_train_index,
     _historical_forecasts_general_checks,
+    _historical_forecasts_sanitize_kwargs,
     _reconciliate_historical_time_indices,
 )
 from darts.utils.timeseries_generation import (
@@ -816,10 +817,14 @@ def retrain_func(
                 logger,
             )
 
-        if fit_kwargs is None:
-            fit_kwargs = dict()
-        if predict_kwargs is None:
-            predict_kwargs = dict()
+        # remove unsupported arguments; raise an exception if they interfere with the historical forecasts logic
+        fit_kwargs, predict_kwargs = _historical_forecasts_sanitize_kwargs(
+            model=model,
+            fit_kwargs=fit_kwargs,
+            predict_kwargs=predict_kwargs,
+            retrain=retrain,
+            show_warnings=show_warnings,
+        )
 
         series = series2seq(series)
         past_covariates = series2seq(past_covariates)
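The predict-side counterpart works the same way with `retrain=False`. A hedged sketch (requires the torch flavor of darts; `NLinearModel` and `batch_size` mirror the new tests below):

from darts.models import NLinearModel
from darts.utils.timeseries_generation import gaussian_timeseries

series = gaussian_timeseries(length=30)
model = NLinearModel(input_chunk_length=3, output_chunk_length=1, n_epochs=1)
model.fit(series=series[:-2])

# supported predict argument: forwarded to model.predict() at each forecast step
hist_fc = model.historical_forecasts(
    series,
    forecast_horizon=1,
    start=len(series) - 2,
    retrain=False,
    predict_kwargs={"batch_size": 10},
)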
155 changes: 127 additions & 28 deletions darts/tests/models/forecasting/test_historical_forecasts.py
@@ -333,6 +333,29 @@ class TestHistoricalforecast:
     # slightly longer to not affect the last predictable timestamp
     ts_covs = tg.gaussian_timeseries(length=30, start=start_ts)
 
+    @staticmethod
+    def create_model(ocl, use_ll=True, model_type="regression"):
+        if model_type == "regression":
+            return LinearRegressionModel(
+                lags=3,
+                likelihood="quantile" if use_ll else None,
+                quantiles=[0.05, 0.4, 0.5, 0.6, 0.95] if use_ll else None,
+                output_chunk_length=ocl,
+            )
+        else:  # model_type == "torch"
+            if not TORCH_AVAILABLE:
+                return None
+            return NLinearModel(
+                input_chunk_length=3,
+                likelihood=QuantileRegression([0.05, 0.4, 0.5, 0.6, 0.95])
+                if use_ll
+                else None,
+                output_chunk_length=ocl,
+                n_epochs=1,
+                random_state=42,
+                **tfm_kwargs,
+            )
+
     def test_historical_forecasts_transferrable_future_cov_local_models(self):
         model = ARIMA()
         assert model.min_train_series_length == 30
@@ -1827,29 +1850,7 @@ def test_predict_likelihood_parameters(self, model_type):
         """standard checks that historical forecasts work with direct likelihood parameter predictions
         with regression and torch models."""
 
-        def create_model(ocl, use_ll=True, model_type="regression"):
-            if model_type == "regression":
-                return LinearRegressionModel(
-                    lags=3,
-                    likelihood="quantile" if use_ll else None,
-                    quantiles=[0.05, 0.4, 0.5, 0.6, 0.95] if use_ll else None,
-                    output_chunk_length=ocl,
-                )
-            else:  # model_type == "torch"
-                if not TORCH_AVAILABLE:
-                    return None
-                return NLinearModel(
-                    input_chunk_length=3,
-                    likelihood=QuantileRegression([0.05, 0.4, 0.5, 0.6, 0.95])
-                    if use_ll
-                    else None,
-                    output_chunk_length=ocl,
-                    n_epochs=1,
-                    random_state=42,
-                    **tfm_kwargs,
-                )
-
-        model = create_model(1, False, model_type=model_type)
+        model = self.create_model(1, False, model_type=model_type)
         # skip torch models if not installed
         if model is None:
             return
@@ -1860,7 +1861,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):
             predict_likelihood_parameters=True,
         )
 
-        model = create_model(1, model_type=model_type)
+        model = self.create_model(1, model_type=model_type)
         # forecast_horizon > output_chunk_length doesn't work
         with pytest.raises(ValueError):
             model.historical_forecasts(
@@ -1869,7 +1870,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):
                 forecast_horizon=2,
             )
 
-        model = create_model(1, model_type=model_type)
+        model = self.create_model(1, model_type=model_type)
         # num_samples != 1 doesn't work
         with pytest.raises(ValueError):
             model.historical_forecasts(
@@ -1884,7 +1885,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):
         qs_expected = ["q0.05", "q0.40", "q0.50", "q0.60", "q0.95"]
         qs_expected = pd.Index([target_name + "_" + q for q in qs_expected])
         # check that it works with retrain
-        model = create_model(1, model_type=model_type)
+        model = self.create_model(1, model_type=model_type)
         hist_fc = model.historical_forecasts(
             self.ts_pass_train,
             predict_likelihood_parameters=True,
@@ -1897,7 +1898,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):
         assert len(hist_fc) == n
 
         # check for equal results between predict and hist fc without retraining
-        model = create_model(1, model_type=model_type)
+        model = self.create_model(1, model_type=model_type)
         model.fit(series=self.ts_pass_train[:-n])
         hist_fc = model.historical_forecasts(
             self.ts_pass_train,
@@ -1926,7 +1927,7 @@ def create_model(ocl, use_ll=True, model_type="regression"):
 
         # check equal results between predict and hist fc with higher output_chunk_length and horizon,
         # and last_points_only=False
-        model = create_model(2, model_type=model_type)
+        model = self.create_model(2, model_type=model_type)
         # we take one more training step so that model trained on ocl=1 has the same training samples
         # as model above
         model.fit(series=self.ts_pass_train[: -(n - 1)])
@@ -1959,3 +1960,101 @@
                 p.all_values(copy=False), hfc.all_values(copy=False)
             )
         assert len(hist_fc) == n + 1
+
+    @pytest.mark.parametrize("model_type", ["regression", "torch"])
+    def test_fit_kwargs(self, monkeypatch, model_type):
+        """check that the parameters provided in fit_kwargs are correctly processed"""
+        valid_fit_kwargs = {"max_samples_per_ts": 3}
+        invalid_fit_kwargs = {"series": self.ts_pass_train}
+        if model_type == "regression":
+            unsupported_fit_kwargs = {"trainer": None}
+        elif model_type == "torch":
+            unsupported_fit_kwargs = {"n_jobs_multioutput_wrapper": False}
+
+        n = 2
+        model = self.create_model(1, use_ll=False, model_type=model_type)
+        model.fit(series=self.ts_pass_train[:-n])
+
+        # supported argument
+        hist_fc = model.historical_forecasts(
+            self.ts_pass_train,
+            forecast_horizon=1,
+            num_samples=1,
+            start=len(self.ts_pass_train) - n,
+            retrain=True,
+            fit_kwargs=valid_fit_kwargs,
+        )
+
+        assert hist_fc.components.equals(self.ts_pass_train.components)
+        assert len(hist_fc) == n
+
+        # passing an unsupported argument
+        hist_fc = model.historical_forecasts(
+            self.ts_pass_train,
+            forecast_horizon=1,
+            start=len(self.ts_pass_train) - n,
+            retrain=True,
+            fit_kwargs=unsupported_fit_kwargs,
+        )
+
+        assert hist_fc.components.equals(self.ts_pass_train.components)
+        assert len(hist_fc) == n
+
+        # passing hist_fc parameters in fit_kwargs, interfering with the logic
+        with pytest.raises(ValueError):
+            hist_fc = model.historical_forecasts(
+                self.ts_pass_train,
+                forecast_horizon=1,
+                start=len(self.ts_pass_train) - n,
+                retrain=True,
+                fit_kwargs=invalid_fit_kwargs,
+            )
+
+    @pytest.mark.parametrize("model_type", ["regression", "torch"])
+    def test_predict_kwargs(self, monkeypatch, model_type):
+        """check that the parameters provided in predict_kwargs are correctly processed"""
+        invalid_predict_kwargs = {"predict_likelihood_parameters": False}
+        if model_type == "regression":
+            valid_predict_kwargs = {}
+            unsupported_predict_kwargs = {"batch_size": 10}
+        elif model_type == "torch":
+            valid_predict_kwargs = {"batch_size": 10}
+            unsupported_predict_kwargs = {}
+
+        n = 2
+        model = self.create_model(1, use_ll=False, model_type=model_type)
+        model.fit(series=self.ts_pass_train[:-n])
+
+        # supported argument
+        hist_fc = model.historical_forecasts(
+            self.ts_pass_train,
+            forecast_horizon=1,
+            start=len(self.ts_pass_train) - n,
+            retrain=False,
+            predict_kwargs=valid_predict_kwargs,
+        )
+
+        assert hist_fc.components.equals(self.ts_pass_train.components)
+        assert len(hist_fc) == n
+
+        # passing an unsupported argument
+        hist_fc = model.historical_forecasts(
+            self.ts_pass_train,
+            forecast_horizon=1,
+            start=len(self.ts_pass_train) - n,
+            retrain=False,
+            predict_kwargs=unsupported_predict_kwargs,
+        )
+
+        assert hist_fc.components.equals(self.ts_pass_train.components)
+        assert len(hist_fc) == n
+
+        # passing hist_fc parameters in predict_kwargs, interfering with the logic
+        with pytest.raises(ValueError):
+            hist_fc = model.historical_forecasts(
+                self.ts_pass_train,
+                forecast_horizon=1,
+                start=len(self.ts_pass_train) - n,
+                retrain=False,
+                predict_kwargs=invalid_predict_kwargs,
+            )
71 changes: 48 additions & 23 deletions darts/utils/historical_forecasts/utils.py
@@ -210,7 +210,16 @@ def _historical_forecasts_general_checks(model, series, kwargs):
             logger,
         )
 
-    if n.fit_kwargs is not None or n.predict_kwargs is not None:
+
+def _historical_forecasts_sanitize_kwargs(
+    model,
+    fit_kwargs: Optional[Dict[str, Any]],
+    predict_kwargs: Optional[Dict[str, Any]],
+    retrain: bool,
+    show_warnings: bool,
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """Convert kwargs to dictionaries and check that their content is compatible with the called methods."""
+    if fit_kwargs is not None or predict_kwargs is not None:
         hfc_args = set(inspect.signature(model.historical_forecasts).parameters)
         # replace `forecast_horizon` with `n`
         hfc_args = hfc_args - {"forecast_horizon"}
@@ -222,57 +231,73 @@
             "val_future_covariates",
         }
 
-    if n.fit_kwargs is not None:
-        if n.retrain:
+    if fit_kwargs is None:
+        fit_kwargs = dict()
+    else:
+        if retrain:
             fit_args = set(inspect.signature(model.fit).parameters)
-            _historical_forecasts_kwargs_checks(
+            fit_kwargs = _historical_forecasts_check_kwargs(
                 hfc_args=hfc_args,
                 name_kwargs="fit_kwargs",
-                dict_kwargs=n.fit_kwargs,
+                dict_kwargs=fit_kwargs,
                 method_args=fit_args,
-                show_warnings=n.show_warnings,
+                show_warnings=show_warnings,
             )
-        elif n.show_warnings:
+        elif show_warnings:
             logger.warning(
                 "`fit_kwargs` was provided with `retrain=False`, the argument will be ignored."
             )
 
-    if n.predict_kwargs is not None:
+    if predict_kwargs is None:
+        predict_kwargs = dict()
+    else:
         predict_args = set(inspect.signature(model.predict).parameters)
-        _historical_forecasts_kwargs_checks(
+        predict_kwargs = _historical_forecasts_check_kwargs(
             hfc_args=hfc_args,
             name_kwargs="predict_kwargs",
-            dict_kwargs=n.predict_kwargs,
+            dict_kwargs=predict_kwargs,
             method_args=predict_args,
-            show_warnings=n.show_warnings,
+            show_warnings=show_warnings,
         )
 
+    return fit_kwargs, predict_kwargs
+
 
-def _historical_forecasts_kwargs_checks(
+def _historical_forecasts_check_kwargs(
     hfc_args: Set[str],
     name_kwargs: str,
     dict_kwargs: Dict[str, Any],
     method_args: Set[str],
     show_warnings: bool,
-):
+) -> Dict[str, Any]:
     """
-    Return a warning if some argument are not supported and an exception if some arguments interfere with
-    historical_forecasts logic
+    Return the kwargs dict without the arguments unsupported by the model method.
+    Warn if some arguments are not supported, and raise an exception if any would interfere with the
+    historical_forecasts logic.
     """
-    ignored_args = set(dict_kwargs) - method_args
-    if show_warnings and len(ignored_args) > 0:
-        logger.warning(
-            f"The following parameters in `{name_kwargs}` will be ignored was they are not supported by "
-            f"the model method : {ignored_args}."
-        )
     invalid_args = set(dict_kwargs).intersection(hfc_args)
     if len(invalid_args) > 0:
         raise_log(
-            f"The following parameters cannot be passed using `{name_kwargs}` as they would interfere with "
-            f"historical forecasts logic : {invalid_args}.",
+            ValueError(
+                f"The following parameters cannot be passed using `{name_kwargs}` as they would interfere with "
+                f"the historical forecasts logic: {invalid_args}."
+            ),
             logger,
         )
 
+    ignored_args = set(dict_kwargs) - method_args
+    if len(ignored_args) > 0:
+        # remove unsupported arguments to avoid an exception raised by Python
+        dict_kwargs = {k: v for k, v in dict_kwargs.items() if k not in ignored_args}
+        if show_warnings:
+            logger.warning(
+                f"The following parameters in `{name_kwargs}` will be ignored as they are not supported by "
+                f"the model method: {ignored_args}."
+            )
+
+    return dict_kwargs
+
 
 def _historical_forecasts_start_warnings(
     idx: int,
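To make the filtering rule above concrete, here is a self-contained emulation of the decision logic in `_historical_forecasts_check_kwargs` (plain Python with no darts dependency; all names are illustrative stand-ins, not the library's API):

import inspect
from typing import Any, Dict, Set

def check_kwargs_sketch(
    hfc_args: Set[str],
    name_kwargs: str,
    dict_kwargs: Dict[str, Any],
    method_args: Set[str],
    show_warnings: bool = True,
) -> Dict[str, Any]:
    # arguments that historical_forecasts itself controls must not be overridden
    invalid_args = set(dict_kwargs).intersection(hfc_args)
    if invalid_args:
        raise ValueError(
            f"The following parameters cannot be passed using `{name_kwargs}`: {invalid_args}."
        )
    # arguments the target method does not accept are dropped, with an optional warning
    ignored_args = set(dict_kwargs) - method_args
    if ignored_args and show_warnings:
        print(f"ignoring unsupported parameters in `{name_kwargs}`: {ignored_args}")
    return {k: v for k, v in dict_kwargs.items() if k not in ignored_args}

def fake_fit(series=None, max_samples_per_ts=None):  # stand-in for model.fit
    pass

method_args = set(inspect.signature(fake_fit).parameters)
hfc_args = {"series", "start", "forecast_horizon"}

print(check_kwargs_sketch(hfc_args, "fit_kwargs", {"max_samples_per_ts": 3}, method_args))
# -> {'max_samples_per_ts': 3}
print(check_kwargs_sketch(hfc_args, "fit_kwargs", {"trainer": None}, method_args))
# -> {} after a warning, matching the "unsupported argument" test above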
