Skip to content

Commit

Permalink
rework_conformal
Browse files Browse the repository at this point in the history
  • Loading branch information
elephaint committed Oct 17, 2024
1 parent 6f2272c commit a4c8b54
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 210 deletions.
34 changes: 24 additions & 10 deletions nbs/common.model_checks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"outputs": [],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
Expand Down Expand Up @@ -220,6 +211,29 @@
" except RuntimeError:\n",
" raise Exception(f\"{model_class.__name__}: AirPassengers forecast test failed.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"#| hide\n",
"# Run tests in this file. This is a slow test\n",
"import warnings\n",
"import logging\n",
"from neuralforecast.models import RNN, GRU, TCN, LSTM, DeepAR, DilatedRNN, BiTCN, MLP, NBEATS, NBEATSx, NHITS, DLinear, NLinear, TiDE, DeepNPTS, TFT, VanillaTransformer, Informer, Autoformer, FEDformer, TimesNet, iTransformer, KAN, RMoK, StemGNN, TSMixer, TSMixerx, MLPMultivariate, SOFTS, TimeMixer\n",
"\n",
"models = [RNN, GRU, TCN, LSTM, DeepAR, DilatedRNN, BiTCN, MLP, NBEATS, NBEATSx, NHITS, DLinear, NLinear, TiDE, DeepNPTS, TFT, VanillaTransformer, Informer, Autoformer, FEDformer, TimesNet, iTransformer, KAN, RMoK, StemGNN, TSMixer, TSMixerx, MLPMultivariate, SOFTS, TimeMixer]\n",
"\n",
"logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
"logging.getLogger(\"lightning_fabric\").setLevel(logging.ERROR)\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" for model in models:\n",
" check_model(model, checks=[\"losses\"])"
]
}
],
"metadata": {
Expand Down
113 changes: 69 additions & 44 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -738,13 +738,14 @@
" names: List[str] = []\n",
" count_names = {'model': 0}\n",
" for model in self.models:\n",
" if add_level and (model.loss.outputsize_multiplier > 1 or isinstance(model.loss, IQLoss)):\n",
" continue\n",
"\n",
" model_name = repr(model)\n",
" count_names[model_name] = count_names.get(model_name, -1) + 1\n",
" if count_names[model_name] > 0:\n",
" model_name += str(count_names[model_name])\n",
"\n",
" if add_level and (model.loss.outputsize_multiplier > 1 or isinstance(model.loss, IQLoss)):\n",
" continue\n",
"\n",
" names.extend(model_name + n for n in model.loss.output_names)\n",
" return names\n",
"\n",
Expand Down Expand Up @@ -906,6 +907,7 @@
" raise Exception(\"You must fit the model before predicting.\")\n",
" \n",
" quantiles_ = None\n",
" level_ = None\n",
" has_level = False \n",
" if level is not None:\n",
" has_level = True\n",
Expand Down Expand Up @@ -1012,7 +1014,7 @@
" self._scalers_transform(futr_dataset)\n",
" dataset = dataset.append(futr_dataset)\n",
" \n",
" fcsts, cols = self._generate_forecasts(dataset=dataset, quantiles_=quantiles_, has_level=has_level, **data_kwargs)\n",
" fcsts, cols = self._generate_forecasts(dataset=dataset, uids=uids, quantiles_=quantiles_, level_=level_, has_level=has_level, **data_kwargs)\n",
" \n",
" if self.scalers_:\n",
" indptr = np.append(0, np.full(len(uids), self.h).cumsum())\n",
Expand All @@ -1028,26 +1030,26 @@
" _warn_id_as_idx()\n",
" fcsts_df = fcsts_df.set_index(self.id_col)\n",
"\n",
" # add prediction intervals or quantiles to models trained with point loss functions via level argument\n",
" if level is not None or quantiles is not None:\n",
" model_names = self._get_model_names(add_level=True)\n",
" if model_names:\n",
" if self.prediction_intervals is None:\n",
" raise AttributeError(\n",
" \"You have trained one or more models with a point loss function (e.g. MAE, MSE). \"\n",
" \"You then must set `prediction_intervals` during fit to use level or quantiles during predict.\") \n",
" prediction_interval_method = get_prediction_interval_method(self.prediction_intervals.method)\n",
"\n",
" fcsts_df = prediction_interval_method(\n",
" fcsts_df,\n",
" self._cs_df,\n",
" model_names=list(model_names),\n",
" level=level_ if level is not None else None,\n",
" cs_n_windows=self.prediction_intervals.n_windows,\n",
" n_series=len(uids),\n",
" horizon=self.h,\n",
" quantiles=quantiles_ if quantiles is not None else None,\n",
" ) \n",
" # # add prediction intervals or quantiles to models trained with point loss functions via level argument\n",
" # if level is not None or quantiles is not None:\n",
" # model_names = self._get_model_names(add_level=True)\n",
" # if model_names:\n",
" # if self.prediction_intervals is None:\n",
" # raise AttributeError(\n",
" # \"You have trained one or more models with a point loss function (e.g. MAE, MSE). \"\n",
" # \"You then must set `prediction_intervals` during fit to use level or quantiles during predict.\") \n",
" # prediction_interval_method = get_prediction_interval_method(self.prediction_intervals.method)\n",
"\n",
" # fcsts_df = prediction_interval_method(\n",
" # fcsts_df,\n",
" # self._cs_df,\n",
" # model_names=list(model_names),\n",
" # level=level_ if level is not None else None,\n",
" # cs_n_windows=self.prediction_intervals.n_windows,\n",
" # n_series=len(uids),\n",
" # horizon=self.h,\n",
" # quantiles=quantiles_ if quantiles is not None else None,\n",
" # ) \n",
"\n",
" return fcsts_df\n",
"\n",
Expand Down Expand Up @@ -1696,7 +1698,7 @@
" dropped = list(set(cv_results.columns) - set(kept))\n",
" return ufp.drop_columns(cv_results, dropped) \n",
" \n",
" def _generate_forecasts(self, dataset: TimeSeriesDataset, quantiles_: Optional[List[float]] = None, has_level: Optional[bool] = False, **data_kwargs) -> np.array:\n",
" def _generate_forecasts(self, dataset: TimeSeriesDataset, uids: Series, quantiles_: Optional[List[float]] = None, level_: Optional[List[Union[int, float]]] = None, has_level: Optional[bool] = False, **data_kwargs) -> np.array:\n",
" fcsts_list: List = []\n",
" cols = []\n",
" count_names = {'model': 0}\n",
Expand All @@ -1711,6 +1713,7 @@
" model_name += str(count_names[model_name])\n",
"\n",
" # Predict for every quantile or level if requested and the loss function supports it\n",
" # case 1: DistributionLoss and MixtureLosses\n",
" if quantiles_ is not None and not isinstance(model.loss, IQLoss) and hasattr(model.loss, 'update_quantile') and callable(model.loss.update_quantile):\n",
" model_fcsts = model.predict(dataset=dataset, quantiles = quantiles_, **data_kwargs)\n",
" fcsts_list.append(model_fcsts) \n",
Expand All @@ -1725,6 +1728,7 @@
" cols.extend(col_names + [model_name + param_name for param_name in model.loss.param_names])\n",
" else:\n",
" cols.extend(col_names)\n",
" # case 2: IQLoss\n",
" elif quantiles_ is not None and isinstance(model.loss, IQLoss):\n",
" col_names = []\n",
" for i, quantile in enumerate(quantiles_):\n",
Expand All @@ -1733,6 +1737,27 @@
" col_name = self._get_column_name(model_name, quantile, has_level)\n",
" col_names.extend([col_name]) \n",
" cols.extend(col_names)\n",
" # case 3: PointLoss via prediction intervals\n",
" elif quantiles_ is not None and model.loss.outputsize_multiplier == 1:\n",
" if self.prediction_intervals is None:\n",
" raise AttributeError(\n",
" f\"You have trained {model_name} with loss={type(model.loss).__name__}(). \\n\"\n",
" \" You then must set `prediction_intervals` during fit to use level or quantiles during predict.\") \n",
" model_fcsts = model.predict(dataset=dataset, quantiles = quantiles_, **data_kwargs)\n",
" prediction_interval_method = get_prediction_interval_method(self.prediction_intervals.method)\n",
" fcsts_with_intervals, out_cols = prediction_interval_method(\n",
" model_fcsts,\n",
" self._cs_df,\n",
" model=model_name,\n",
" level=level_ if has_level else None,\n",
" cs_n_windows=self.prediction_intervals.n_windows,\n",
" n_series=len(uids),\n",
" horizon=self.h,\n",
" quantiles=quantiles_ if not has_level else None,\n",
" ) \n",
" fcsts_list.append(fcsts_with_intervals) \n",
" cols.extend([model_name] + out_cols)\n",
" # base case: quantiles or levels are not supported or provided as arguments\n",
" else:\n",
" model_fcsts = model.predict(dataset=dataset, **data_kwargs)\n",
" fcsts_list.append(model_fcsts)\n",
Expand Down Expand Up @@ -3530,34 +3555,34 @@
"\n",
"nf = NeuralForecast(models=models, freq='M')\n",
"nf.fit(AirPassengersPanel_train, prediction_intervals=prediction_intervals)\n",
"# Test default prediction and correct columns\n",
"# Test default prediction\n",
"preds = nf.predict(futr_df=AirPassengersPanel_test)\n",
"assert list(preds.columns) == ['unique_id', 'ds', 'NHITS', 'NHITS1', 'NHITS1-median', 'NHITS1-lo-90',\n",
" 'NHITS1-lo-80', 'NHITS1-hi-80', 'NHITS1-hi-90', 'NHITS2_ql0.5', 'LSTM',\n",
" 'LSTM1', 'LSTM1-median', 'LSTM1-lo-90', 'LSTM1-lo-80', 'LSTM1-hi-80',\n",
" 'LSTM1-hi-90', 'LSTM2_ql0.5', 'TSMixer', 'TSMixer1', 'TSMixer1-median',\n",
" 'TSMixer1-lo-90', 'TSMixer1-lo-80', 'TSMixer1-hi-80', 'TSMixer1-hi-90',\n",
" 'TSMixer2_ql0.5']\n",
"# Test multiple quantile prediction and correct columns\n",
"# Test quantile prediction\n",
"preds = nf.predict(futr_df=AirPassengersPanel_test, quantiles=[0.2, 0.3])\n",
"assert list(preds.columns) == ['unique_id', 'ds', 'NHITS', 'NHITS1', 'NHITS1_ql0.2', 'NHITS1_ql0.3',\n",
" 'NHITS2_ql0.2', 'NHITS2_ql0.3', 'LSTM', 'LSTM1', 'LSTM1_ql0.2',\n",
" 'LSTM1_ql0.3', 'LSTM2_ql0.2', 'LSTM2_ql0.3', 'TSMixer', 'TSMixer1',\n",
" 'TSMixer1_ql0.2', 'TSMixer1_ql0.3', 'TSMixer2_ql0.2', 'TSMixer2_ql0.3',\n",
" 'NHITS-ql0.2', 'NHITS-ql0.3', 'LSTM-ql0.2', 'LSTM-ql0.3',\n",
" 'TSMixer-ql0.2', 'TSMixer-ql0.3']\n",
"# Test multiple level prediction and correct columns\n",
"assert list(preds.columns) == ['unique_id', 'ds', 'NHITS', 'NHITS-ql0.2', 'NHITS-ql0.3', 'NHITS1',\n",
" 'NHITS1_ql0.2', 'NHITS1_ql0.3', 'NHITS2_ql0.2', 'NHITS2_ql0.3', 'LSTM',\n",
" 'LSTM-ql0.2', 'LSTM-ql0.3', 'LSTM1', 'LSTM1_ql0.2', 'LSTM1_ql0.3',\n",
" 'LSTM2_ql0.2', 'LSTM2_ql0.3', 'TSMixer', 'TSMixer-ql0.2',\n",
" 'TSMixer-ql0.3', 'TSMixer1', 'TSMixer1_ql0.2', 'TSMixer1_ql0.3',\n",
" 'TSMixer2_ql0.2', 'TSMixer2_ql0.3']\n",
"# Test level prediction\n",
"preds = nf.predict(futr_df=AirPassengersPanel_test, level=[80, 90])\n",
"assert list(preds.columns) == ['unique_id', 'ds', 'NHITS', 'NHITS1', 'NHITS1-lo-90', 'NHITS1-lo-80',\n",
" 'NHITS1-hi-80', 'NHITS1-hi-90', 'NHITS2-lo-90', 'NHITS2-lo-80',\n",
" 'NHITS2-hi-80', 'NHITS2-hi-90', 'LSTM', 'LSTM1', 'LSTM1-lo-90',\n",
" 'LSTM1-lo-80', 'LSTM1-hi-80', 'LSTM1-hi-90', 'LSTM2-lo-90',\n",
" 'LSTM2-lo-80', 'LSTM2-hi-80', 'LSTM2-hi-90', 'TSMixer', 'TSMixer1',\n",
" 'TSMixer1-lo-90', 'TSMixer1-lo-80', 'TSMixer1-hi-80', 'TSMixer1-hi-90',\n",
" 'TSMixer2-lo-90', 'TSMixer2-lo-80', 'TSMixer2-hi-80', 'TSMixer2-hi-90',\n",
" 'NHITS-lo-90', 'NHITS-lo-80', 'NHITS-hi-80', 'NHITS-hi-90',\n",
" 'LSTM-lo-90', 'LSTM-lo-80', 'LSTM-hi-80', 'LSTM-hi-90', 'TSMixer-lo-90',\n",
" 'TSMixer-lo-80', 'TSMixer-hi-80', 'TSMixer-hi-90']\n",
"assert list(preds.columns) == ['unique_id', 'ds', 'NHITS', 'NHITS-lo-90', 'NHITS-lo-80', 'NHITS-hi-80',\n",
" 'NHITS-hi-90', 'NHITS1', 'NHITS1-lo-90', 'NHITS1-lo-80', 'NHITS1-hi-80',\n",
" 'NHITS1-hi-90', 'NHITS2-lo-90', 'NHITS2-lo-80', 'NHITS2-hi-80',\n",
" 'NHITS2-hi-90', 'LSTM', 'LSTM-lo-90', 'LSTM-lo-80', 'LSTM-hi-80',\n",
" 'LSTM-hi-90', 'LSTM1', 'LSTM1-lo-90', 'LSTM1-lo-80', 'LSTM1-hi-80',\n",
" 'LSTM1-hi-90', 'LSTM2-lo-90', 'LSTM2-lo-80', 'LSTM2-hi-80',\n",
" 'LSTM2-hi-90', 'TSMixer', 'TSMixer-lo-90', 'TSMixer-lo-80',\n",
" 'TSMixer-hi-80', 'TSMixer-hi-90', 'TSMixer1', 'TSMixer1-lo-90',\n",
" 'TSMixer1-lo-80', 'TSMixer1-hi-80', 'TSMixer1-hi-90', 'TSMixer2-lo-90',\n",
" 'TSMixer2-lo-80', 'TSMixer2-hi-80', 'TSMixer2-hi-90']\n",
"# Re-Test default prediction - note that they are different from the first test (this is expected)\n",
"preds = nf.predict(futr_df=AirPassengersPanel_test)\n",
"assert list(preds.columns) == ['unique_id', 'ds', 'NHITS', 'NHITS1', 'NHITS1-median', 'NHITS2_ql0.5',\n",
Expand Down
Loading

0 comments on commit a4c8b54

Please sign in to comment.