unify_quantile_and_level_in_predict
elephaint committed Oct 17, 2024
1 parent 0b980c0 commit 63984e6
Showing 10 changed files with 1,417 additions and 310 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -40,4 +40,4 @@ jobs:
           uv pip install --system "numpy<2" ".[dev]"
       - name: Tests
-        run: nbdev_test --do_print --timing --flags polars
+        run: nbdev_test --do_print --timing --n_workers 0 --flags polars
45 changes: 14 additions & 31 deletions nbs/common.base_model.ipynb
@@ -393,23 +393,11 @@
     "            set(temporal_cols.tolist()) & set(self.hist_exog_list + self.futr_exog_list)\n",
     "        )\n",
     "\n",
-    "    def _set_quantile(self, **data_module_kwargs):\n",
-    "        if \"quantile\" in data_module_kwargs:\n",
-    "            supported_losses = (losses.IQLoss, losses.DistributionLoss,\n",
-    "                                losses.GMM, losses.PMM, losses.NBMM)\n",
-    "            if not isinstance(self.loss, supported_losses):\n",
-    "                raise Exception(\n",
-    "                    f\"Please train with one of {supported_losses} to make use of the quantile argument.\"\n",
-    "                )\n",
-    "            else:\n",
-    "                self.quantile = data_module_kwargs[\"quantile\"]\n",
-    "                data_module_kwargs.pop(\"quantile\")\n",
-    "                self.loss.update_quantile(q=self.quantile)\n",
-    "        elif isinstance(self.loss, losses.IQLoss):\n",
-    "            self.quantile = 0.5\n",
-    "            self.loss.update_quantile(q=self.quantile)\n",
-    "\n",
-    "        return data_module_kwargs\n",
+    "    def _set_quantiles(self, quantiles=None):\n",
+    "        if quantiles is None and isinstance(self.loss, losses.IQLoss):\n",
+    "            self.loss.update_quantile(q=[0.5])\n",
+    "        elif hasattr(self.loss, 'update_quantile') and callable(self.loss.update_quantile):\n",
+    "            self.loss.update_quantile(q=quantiles)\n",
     "\n",
     "    def _fit_distributed(\n",
     "        self,\n",
@@ -1066,10 +1054,7 @@
     "            insample_y = self.scaler.scaler(mean, y_loc, y_scale)\n",
     "\n",
     "            # Save predictions\n",
-    "            if self.loss.predict_single_quantile:\n",
-    "                y_hat = quants\n",
-    "            else:\n",
-    "                y_hat = torch.concat((mean.unsqueeze(-1), quants), axis=-1)\n",
+    "            y_hat = torch.concat((mean.unsqueeze(-1), quants), axis=-1)\n",
     "\n",
     "            if self.loss.return_params:\n",
     "                distr_args = torch.stack(distr_args, dim=-1)\n",
@@ -1108,12 +1093,8 @@
     "        if self.loss.is_distribution_output:\n",
     "            y_loc, y_scale = self._get_loc_scale(y_idx)\n",
     "            distr_args = self.loss.scale_decouple(output=output_batch, loc=y_loc, scale=y_scale)\n",
-    "            if self.loss.predict_single_quantile:\n",
-    "                _, _, quant = self.loss.sample(distr_args=distr_args)\n",
-    "                y_hat = quant\n",
-    "            else:\n",
-    "                _, sample_mean, quants = self.loss.sample(distr_args=distr_args)\n",
-    "                y_hat = torch.concat((sample_mean, quants), axis=-1)\n",
+    "            _, sample_mean, quants = self.loss.sample(distr_args=distr_args)\n",
+    "            y_hat = torch.concat((sample_mean, quants), axis=-1)\n",
     "\n",
     "            if self.loss.return_params:\n",
     "                distr_args = torch.stack(distr_args, dim=-1)\n",
@@ -1337,7 +1318,7 @@
     "        )\n",
     "\n",
     "    def predict(self, dataset, test_size=None, step_size=1,\n",
-    "                random_seed=None, **data_module_kwargs):\n",
+    "                random_seed=None, quantiles=None, **data_module_kwargs):\n",
     "        \"\"\" Predict.\n",
     "\n",
     "        Neural network prediction with PL's `Trainer` execution of `predict_step`.\n",
@@ -1347,11 +1328,12 @@
     "        `test_size`: int=None, test size for temporal cross-validation.<br>\n",
     "        `step_size`: int=1, Step size between each window.<br>\n",
     "        `random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.<br>\n",
+    "        `quantiles`: list of floats, optional (default=None), target quantiles to predict.<br>\n",
     "        `**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).\n",
     "        \"\"\"\n",
     "        self._check_exog(dataset)\n",
     "        self._restart_seed(random_seed)\n",
-    "        data_module_kwargs = self._set_quantile(**data_module_kwargs)\n",
+    "        self._set_quantiles(quantiles)\n",
     "\n",
     "        self.predict_step_size = step_size\n",
     "        self.decompose_forecast = False\n",
@@ -1377,7 +1359,7 @@
     "        fcsts = fcsts.reshape(-1, len(self.loss.output_names))\n",
     "        return fcsts\n",
     "\n",
-    "    def decompose(self, dataset, step_size=1, random_seed=None, **data_module_kwargs):\n",
+    "    def decompose(self, dataset, step_size=1, random_seed=None, quantiles=None, **data_module_kwargs):\n",
     "        \"\"\" Decompose Predictions.\n",
     "\n",
     "        Decompose the predictions through the network's layers.\n",
@@ -1386,13 +1368,14 @@
     "        **Parameters:**<br>\n",
     "        `dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation here](https://nixtla.github.io/neuralforecast/tsdataset.html).<br>\n",
     "        `step_size`: int=1, step size between each window of temporal data.<br>\n",
+    "        `quantiles`: list of floats, optional (default=None), target quantiles to predict.<br>\n",
     "        `**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).\n",
     "        \"\"\"\n",
     "        # Restart random seed\n",
     "        if random_seed is None:\n",
     "            random_seed = self.random_seed\n",
     "        torch.manual_seed(random_seed)\n",
-    "        data_module_kwargs = self._set_quantile(**data_module_kwargs)\n",
+    "        self._set_quantiles(quantiles)\n",
     "\n",
     "        self.predict_step_size = step_size\n",
     "        self.decompose_forecast = True\n",
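Taken together, the commit replaces the per-call `quantile` kwarg (previously smuggled through `**data_module_kwargs`, one quantile at a time, and only for whitelisted losses) with an explicit `quantiles` list on `predict` and `decompose`. A usage sketch of the resulting interface; the high-level `NeuralForecast.predict(quantiles=...)` pass-through is assumed from the commit title and the 8 changed files not shown here, and the data/model choices are illustrative:

    import pandas as pd
    from neuralforecast import NeuralForecast
    from neuralforecast.models import NHITS
    from neuralforecast.losses.pytorch import DistributionLoss

    # Toy monthly series, illustrative only.
    df = pd.DataFrame({
        "unique_id": "series_1",
        "ds": pd.date_range("2023-01-01", periods=48, freq="MS"),
        "y": range(48),
    })

    nf = NeuralForecast(
        models=[NHITS(h=12, input_size=24,
                      loss=DistributionLoss(distribution="Normal"),
                      max_steps=10)],
        freq="MS",
    )
    nf.fit(df)

    # One unified argument instead of a single quantile per call:
    # every requested quantile comes back alongside the mean forecast.
    forecasts = nf.predict(quantiles=[0.1, 0.5, 0.9])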