From b15d8bbe0ce7120298ef01da043d80adbee70815 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Wed, 26 Oct 2022 18:58:42 +0200 Subject: [PATCH 01/18] implementation of the deeptime forecasting model, creation of the associated tests (inspired from the nbeats tests) --- darts/models/forecasting/deeptime.py | 497 ++++++++++++++++++ .../tests/models/forecasting/test_deeptime.py | 185 +++++++ 2 files changed, 682 insertions(+) create mode 100644 darts/models/forecasting/deeptime.py create mode 100644 darts/tests/models/forecasting/test_deeptime.py diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py new file mode 100644 index 0000000000..41251c945a --- /dev/null +++ b/darts/models/forecasting/deeptime.py @@ -0,0 +1,497 @@ +""" +DeepTime +------- +""" + +from typing import List, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from darts.logging import get_logger, raise_if_not +from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule +from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel +from darts.utils.torch import MonteCarloDropout + +logger = get_logger(__name__) + + +ACTIVATIONS = [ + "ReLU", + "RReLU", + "PReLU", + "ELU", + "Softplus", + "Tanh", + "SELU", + "LeakyReLU", + "Sigmoid", + "GELU", +] + + +class GaussianFourierFeatureTransform(nn.Module): + """ + https://github.com/ndahlquist/pytorch-fourier-feature-networks + Given an input of size [..., time, dim], returns a tensor of size [..., n_fourier_feats, time]. + """ + + def __init__(self, input_dim: int, n_fourier_feats: int, scales: List[float]): + super().__init__() + self.input_dim = input_dim + self.n_fourier_feats = n_fourier_feats + self.scales = scales + + n_scale_feats = n_fourier_feats // (2 * len(scales)) + B_size = (input_dim, n_scale_feats) + # Sample Fourier components + B = torch.cat([torch.randn(B_size) * scale for scale in scales], dim=1) + self.register_buffer("B", B) + + def forward(self, x: Tensor) -> Tensor: + raise_if_not( + x.dim() >= 2, + f"Expected 2 or more dimensional input (got {x.dim()}D input)", + logger, + ) + dim = x.shape[-1] + raise_if_not( + dim == self.input_dim, + f"Expected input to have {self.input_dim} channels (got {dim} channels)", + logger, + ) + + x = torch.einsum("... t n, n d -> ... t d", [x, self.B]) + x = 2 * np.pi * x + return torch.cat([torch.sin(x), torch.cos(x)], dim=-1) + + +class INR(nn.Module): + def __init__( + self, + input_size: int, + num_layers: int, + hidden_layers_width: int, + n_fourier_feats: int, + scales: list[float], + dropout: float, + activation: str, + ): + """Implicit Neural Representation, mapping values to their coordinates using a Multi-Layer Perceptron. + + Features can be encoded using either a Linear layer or a Gaussian Fourier Transform + + Parameters + ---------- + input_size + The dimensionality of the input time series. + num_layers + The number of fully connected layers. + hidden_layers_width + Determines the number of neurons that make up each hidden fully connected layer. + If a list is passed, it must have a length equal to `num_layers`. If an integer is passed, + every layers will have the same width. 
+ n_fourier_feats + Number of Fourier components to sample to represent to time-serie in the frequency domain + scales + Scaling factors applied to the normal distribution sampled for Fourier components' magnitude + dropout + The fraction of neurons that are dropped at each layer. + activation + The activation function of fully connected network intermediate layers. + + Inputs + ------ + x of shape `(batch_size, input_size)` + Tensor containing the features of the input sequence. + + Outputs + ------- + y of shape `(batch_size, output_chunk_length, target_size, nr_params)` + Tensor containing the prediction at the last time step of the sequence. + """ + super().__init__() + + self.input_size = input_size + self.num_layers = num_layers + self.n_fourier_feats = n_fourier_feats + self.scales = scales + self.dropout = dropout + + raise_if_not( + activation in ACTIVATIONS, f"'{activation}' is not in {ACTIVATIONS}" + ) + self.activation = getattr(nn, activation)() + + if isinstance(hidden_layers_width, int): + self.hidden_layers_width = [hidden_layers_width] * self.num_layers + else: + self.hidden_layers_width = hidden_layers_width + + if n_fourier_feats == 0: + feats_size = self.hidden_layers_width[0] + self.features = nn.Linear(self.input_size, feats_size) + else: + feats_size = self.n_fourier_feats + self.features = GaussianFourierFeatureTransform( + self.input_size, feats_size, self.scales + ) + + # Fully Connected Network + # TODO : solve ambiguity between the num of layers and the number of hidden layers + last_width = feats_size + linear_layer_stack_list = [] + for layer_width in self.hidden_layers_width: + linear_layer_stack_list.append(nn.Linear(last_width, layer_width)) + linear_layer_stack_list.append(self.activation) + + if self.dropout > 0: + linear_layer_stack_list.append(MonteCarloDropout(p=self.dropout)) + + linear_layer_stack_list.append(nn.LayerNorm(layer_width)) + + last_width = layer_width + + self.layers = nn.Sequential(*linear_layer_stack_list) + + def forward(self, x: Tensor) -> Tensor: + x = self.features(x) + return self.layers(x) + + +class RidgeRegressor(nn.Module): + def __init__(self, lambda_init: float = 0.0): + super().__init__() + self._lambda = nn.Parameter(torch.as_tensor(lambda_init, dtype=torch.float)) + + def forward(self, reprs: Tensor, x: Tensor, reg_coeff: float = None) -> Tensor: + if reg_coeff is None: + reg_coeff = self.reg_coeff() + w, b = self.get_weights(reprs, x, reg_coeff) + return w, b + + def get_weights(self, X: Tensor, Y: Tensor, reg_coeff: float) -> Tensor: + batch_size, n_samples, n_dim = X.shape + ones = torch.ones(batch_size, n_samples, 1, device=X.device) + X = torch.concat([X, ones], dim=-1) + + if n_samples >= n_dim: + # standard + A = torch.bmm(X.mT, X) + A.diagonal(dim1=-2, dim2=-1).add_(reg_coeff) + B = torch.bmm(X.mT, Y) + weights = torch.linalg.solve(A, B) + else: + # Woodbury + A = torch.bmm(X, X.mT) + A.diagonal(dim1=-2, dim2=-1).add_(reg_coeff) + weights = torch.bmm(X.mT, torch.linalg.solve(A, Y)) + + return weights[:, :-1], weights[:, -1:] + + def reg_coeff(self) -> Tensor: + return F.softplus(self._lambda) + + +class _DeepTimeModule(PLPastCovariatesModule): + def __init__( + self, + input_dim: int, + output_dim: int, + nr_params: int, + inr_num_layers: int, + inr_layers_width: Union[int, List[int]], + n_fourier_feats: int, + scales: list[float], + dropout: float, + activation: str, + **kwargs, + ): + """PyTorch module implementing the DeepTIMe architecture. 
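+        The module learns an Implicit Neural Representation (INR) of the [0, 1]-normalised time
+        index covering the lookback and horizon windows; a closed-form ridge regression is then
+        fitted on the lookback representations against the input window, and the resulting
+        weights are applied to the horizon representations to produce the forecast.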
+ + Parameters + ---------- + input_chunk_length + The length of the input sequence (lookback) fed to the model. + output_chunk_length + The length of the forecast (horizon) of the model. + inr_num_layers + The number of fully connected layers in the INR module. + inr_layers_width + Determines the number of neurons that make up each hidden fully connected layer of the INR module. + If a list is passed, it must have a length equal to `num_layers`. If an integer is passed, + every layers will have the same width. + n_fourier_feats + Number of Fourier components to sample to represent to time-serie in the frequency domain + scales + Scaling factors applied to the normal distribution sampled for Fourier components' magnitude + dropout + The dropout probability to be used in fully connected layers (default=0). This is compatible with + Monte Carlo dropout at inference time for model uncertainty estimation (enabled with + ``mc_dropout=True`` at prediction time). + activation + The activation function of encoder/decoder intermediate layer (default='ReLU'). + Supported activations: ['ReLU','RReLU', 'PReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid'] + **kwargs + Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and + Darts' :class:`TorchForecastingModel`. + """ + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim + # TODO: develop support for nr_params functionality + self.nr_params = nr_params + self.inr_num_layers = inr_num_layers + self.inr_layers_width = inr_layers_width + self.n_fourier_feats = n_fourier_feats + self.scales = scales + + self.dropout = dropout + self.activation = activation + + # single INR for both univariate and multivariate + self.inr = INR( + input_size=self.input_chunk_length, + num_layers=self.inr_num_layers, + hidden_layers_width=self.inr_layers_width, + n_fourier_feats=self.n_fourier_feats, + scales=self.scales, + dropout=self.dropout, + activation=self.activation, + ) + + # lambda_ constructor argument should be connected to torch lightning + self.adaptive_weights = RidgeRegressor() + + def forward(self, x_in: Tensor) -> Tensor: + x, _ = x_in # x_in: (past_target|past_covariate, static_covariates) + + batch_size, _, in_dim = x.shape # x: (batch, in_len, in_dim) + coords = self.get_coords(self.input_chunk_length, self.output_chunk_length).to( + x.device + ) + + """ + I don't understand why this code is there... Why would the time have last dim greater than 0?! + if y_time.shape[-1] != 0: + time = torch.cat([x_time, y_time], dim=1) + coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0]) + coords = torch.cat([coords, time], dim=-1) + time_reprs = self.inr(coords) + """ + + # multivariate + if self.output_dim > 1: + time = torch.zeros( + size=( + batch_size, + self.input_chunk_length + self.output_chunk_length, + self.output_dim, + ) + ).to(x.device) + coords = coords.repeat(batch_size, 1, 1) + coords = torch.cat([coords, time], dim=-1) + time_reprs = self.inr(coords) + # univariate + else: + time_reprs = self.inr(coords) + time_reprs = time_reprs.repeat(batch_size, 1, 1) + + lookback_reprs = time_reprs[:, : -self.output_chunk_length] + horizon_reprs = time_reprs[:, -self.output_chunk_length :] + # learn weights from the lookback + w, b = self.adaptive_weights(lookback_reprs, x) + # apply weights to the horizon + y = torch.einsum("... d o, ... t d -> ... 
t o", [w, horizon_reprs]) + b + + y = y.view( + y.shape[0], self.output_chunk_length, self.input_dim, self.nr_params + )[:, :, : self.output_dim, :] + return y + + def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor: + """Return time axis encoded as float values between 0 and 1""" + coords = torch.linspace(0, 1, lookback_len + horizon_len) + # would it be clearer with the unsqueeze operator? + return torch.reshape(coords, (1, lookback_len + horizon_len, 1)) + + +class DeepTimeModel(PastCovariatesTorchModel): + def __init__( + self, + input_chunk_length: int, + output_chunk_length: int, + inr_num_layers: int = 5, + inr_layers_width: Union[int, List[int]] = 256, + n_fourier_feats: int = 4096, + scales: list[float] = [0.01, 0.1, 1, 5, 10, 20, 50, 100], + dropout: float = 0.1, + activation: str = "ReLU", + **kwargs, + ): + """Deep time-index model with meta-learning (DeepTIMe). + + This is an implementation of the DeepTime architecture, as outlined in [1]_. The default arguments + correspond to the hyper-parameters described in the published article. + + This model supports past covariates (known for `input_chunk_length` points before prediction time). + + Parameters + ---------- + input_chunk_length + The length of the input sequence (lookback) fed to the model. + output_chunk_length + The length of the forecast (horizon) of the model. + inr_num_layers + The number of fully connected layers in the INR module. + inr_layers_width + Determines the number of neurons that make up each hidden fully connected layer of the INR module. + If a list is passed, it must have a length equal to `num_layers`. If an integer is passed, + every layers will have the same width. + n_fourier_feats + Number of Fourier components to sample to represent to time-serie in the frequency domain + scales + Scaling factors applied to the normal distribution sampled for Fourier components' magnitude + dropout + The dropout probability to be used in fully connected layers (default=0). This is compatible with + Monte Carlo dropout at inference time for model uncertainty estimation (enabled with + ``mc_dropout=True`` at prediction time). + activation + The activation function of encoder/decoder intermediate layer (default='ReLU'). + Supported activations: ['ReLU','RReLU', 'PReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid'] + **kwargs + Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and + Darts' :class:`TorchForecastingModel`. + + loss_fn + PyTorch loss function used for training. + This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified. + Default: ``torch.nn.MSELoss()``. + torch_metrics + A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found + at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``. + optimizer_cls + The PyTorch optimizer class to be used. Default: ``torch.optim.Adam``. + optimizer_kwargs + Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}`` + for specifying a learning rate). Otherwise the default values of the selected ``optimizer_cls`` + will be used. Default: ``None``. + lr_scheduler_cls + Optionally, the PyTorch learning rate scheduler class to be used. Specifying ``None`` corresponds + to using a constant learning rate. Default: ``None``. + lr_scheduler_kwargs + Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. 
+ batch_size + Number of time series (input and output sequences) used in each training pass. Default: ``32``. + n_epochs + Number of epochs over which to train the model. Default: ``100``. + model_name + Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, + defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part + of the name is formatted with the local date and time, while PID is the processed ID (preventing models + spawned at the same time by different processes to share the same model_name). E.g., + ``"2021-06-14_09:53:32_torch_model_run_44607"``. + work_dir + Path of the working directory, where to save checkpoints and Tensorboard summaries. + Default: current working directory. + log_tensorboard + If set, use Tensorboard to log the different parameters. The logs will be located in: + ``"{work_dir}/darts_logs/{model_name}/logs/"``. Default: ``False``. + nr_epochs_val_period + Number of epochs to wait before evaluating the validation loss (if a validation + ``TimeSeries`` is passed to the :func:`fit()` method). Default: ``1``. + force_reset + If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will + be discarded). Default: ``False``. + save_checkpoints + Whether or not to automatically save the untrained model and checkpoints from training. + To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where + :class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`, + :class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using + :func:`save_model()` and loaded using :func:`load_model()`. Default: ``False``. + add_encoders + A large number of past and future covariates can be automatically generated with `add_encoders`. + This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that + will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to + transform the generated covariates. This happens all under one hood and only needs to be specified at + model creation. + Read :meth:`SequentialEncoder ` to find out more about + ``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features: + + .. highlight:: python + .. code-block:: python + + add_encoders={ + 'cyclic': {'future': ['month']}, + 'datetime_attribute': {'future': ['hour', 'dayofweek']}, + 'position': {'past': ['relative'], 'future': ['relative']}, + 'custom': {'past': [lambda idx: (idx.year - 1950) / 50]}, + 'transformer': Scaler() + } + .. + random_state + Control the randomness of the weights initialization. Check this + `link `_ for more details. + Default: ``None``. + + References + ---------- + .. 
[1] https://arxiv.org/abs/2207.06046 + """ + super().__init__(**self._extract_torch_model_params(**self.model_params)) + + # extract pytorch lightning module kwargs + self.pl_module_params = self._extract_pl_module_params(**self.model_params) + + raise_if_not( + isinstance(inr_layers_width, int) + or len(inr_layers_width) == inr_num_layers, + "Please pass an integer or a list of integers with length `inr_num_layers`" + "as value for the `inr_layers_width` argument.", + logger, + ) + + raise_if_not( + n_fourier_feats % (2 * len(scales)) == 0, + f"n_fourier_feats: {n_fourier_feats} must be divisible by 2 * len(scales) = {2 * len(scales)}", + logger, + ) + + self.inr_num_layers = inr_num_layers + self.inr_layers_width = inr_layers_width + self.n_fourier_feats = n_fourier_feats + self.scales = scales + + self.dropout = dropout + self.activation = activation + + # TODO: might actually be True? + @staticmethod + def _supports_static_covariates() -> bool: + return False + + def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: + input_dim = train_sample[0].shape[1] + ( + train_sample[1].shape[1] if train_sample[1] is not None else 0 + ) + output_dim = train_sample[-1].shape[1] + nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters + + # TODO: support nr_params functionnality + model = _DeepTimeModule( + input_dim=input_dim, + output_dim=output_dim, + nr_params=nr_params, + input_chunk_length=self.input_chunk_length, + output_chunk_length=self.output_chunk_length, + inr_num_layers=self.inr_num_layers, + inr_layers_width=self.inr_layers_width, + n_fourier_feats=self.n_fourier_feats, + scales=self.scales, + dropout=self.dropout, + activation=self.activation, + ) + return model diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py new file mode 100644 index 0000000000..17016f1c3e --- /dev/null +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -0,0 +1,185 @@ +import shutil +import tempfile + +import numpy as np + +from darts.logging import get_logger +from darts.tests.base_test_class import DartsBaseTestClass +from darts.utils import timeseries_generation as tg + +logger = get_logger(__name__) + +try: + from darts.models.forecasting.deeptime import DeepTimeModel + + TORCH_AVAILABLE = True +except ImportError: + logger.warning("Torch not available. 
TCN tests will be skipped.") + TORCH_AVAILABLE = False + + +if TORCH_AVAILABLE: + + class DeepTimeModelTestCase(DartsBaseTestClass): + def setUp(self): + self.temp_work_dir = tempfile.mkdtemp(prefix="darts") + + def tearDown(self): + shutil.rmtree(self.temp_work_dir) + + def test_creation(self): + with self.assertRaises(ValueError): + # if a list is passed to the `layer_widths` argument, it must have a length equal to `num_stacks` + DeepTimeModel( + input_chunk_length=1, + output_chunk_length=1, + inr_num_layers=3, + inr_layers_width=[1], + ) + + def test_fit(self): + large_ts = tg.constant_timeseries(length=100, value=1000) + small_ts = tg.constant_timeseries(length=100, value=10) + + # Test basic fit and predict + model = DeepTimeModel( + input_chunk_length=1, + output_chunk_length=1, + n_epochs=10, + inr_num_layers=2, + inr_layers_width=20, + n_fourier_feats=64, + scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + random_state=42, + ) + model.fit(large_ts[:98]) + pred = model.predict(n=2).values()[0] + + # Test whether model trained on one series is better than one trained on another + model2 = DeepTimeModel( + input_chunk_length=1, + output_chunk_length=1, + n_epochs=10, + inr_num_layers=2, + inr_layers_width=20, + n_fourier_feats=64, + scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + random_state=42, + ) + model2.fit(small_ts[:98]) + pred2 = model2.predict(n=2).values()[0] + self.assertTrue(abs(pred2 - 10) < abs(pred - 10)) + + # test short predict + pred3 = model2.predict(n=1) + self.assertEqual(len(pred3), 1) + + def test_multivariate(self): + # testing a 2-variate linear ts, first one from 0 to 1, second one from 0 to 0.5, length 100 + series_multivariate = tg.linear_timeseries(length=100).stack( + tg.linear_timeseries(length=100, start_value=0, end_value=0.5) + ) + + model = DeepTimeModel( + input_chunk_length=3, + output_chunk_length=1, + n_epochs=10, + inr_num_layers=5, + inr_layers_width=32, + n_fourier_feats=256, # must have enough fourier components to capture low freq + scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + random_state=42, + ) + model.fit(series_multivariate) + res = model.predict(n=2).values() + + # the theoretical result should be [[1.01, 1.02], [0.505, 0.51]]. + # We just test if the given result is not too far on average. 
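+            # (assuming evenly spaced values, the first component grows by 1/99 ~ 0.0101 per
+            # step and the second by 0.5/99 ~ 0.00505 per step, hence the expected values above)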
+ self.assertTrue( + abs(np.average(res - np.array([[1.01, 1.02], [0.505, 0.51]])) < 0.03) + ) + + # Test Covariates + series_covariates = tg.linear_timeseries(length=100).stack( + tg.linear_timeseries(length=100, start_value=0, end_value=0.1) + ) + model = DeepTimeModel( + input_chunk_length=3, + output_chunk_length=4, + n_epochs=10, + inr_num_layers=2, + inr_layers_width=20, + n_fourier_feats=64, + scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + random_state=42, + ) + model.fit(series_multivariate, past_covariates=series_covariates) + + res = model.predict( + n=3, series=series_multivariate, past_covariates=series_covariates + ).values() + + self.assertEqual(len(res), 3) + self.assertTrue(abs(np.average(res)) < 5) + + def test_deeptime_n_fourier_feats(self): + with self.assertRaises(ValueError): + # wrong number of scales and n_fourier feats + # n_fourier_feats must be divisiable by 2*len(scales) + DeepTimeModel( + input_chunk_length=1, + output_chunk_length=1, + n_epochs=10, + inr_num_layers=2, + inr_layers_width=20, + n_fourier_feats=17, + scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + random_state=42, + ) + + def test_logtensorboard(self): + ts = tg.constant_timeseries(length=50, value=10) + + # Test basic fit and predict + model = DeepTimeModel( + input_chunk_length=1, + output_chunk_length=1, + n_epochs=10, + inr_num_layers=2, + inr_layers_width=20, + n_fourier_feats=64, + scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + random_state=42, + ) + model.fit(ts) + model.predict(n=2) + + def test_activation_fns(self): + ts = tg.constant_timeseries(length=50, value=10) + + model = DeepTimeModel( + input_chunk_length=1, + output_chunk_length=1, + n_epochs=10, + inr_num_layers=2, + inr_layers_width=20, + n_fourier_feats=64, + scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + activation="LeakyReLU", + random_state=42, + ) + model.fit(ts) + + with self.assertRaises(ValueError): + model = DeepTimeModel( + input_chunk_length=1, + output_chunk_length=1, + n_epochs=10, + inr_num_layers=2, + inr_layers_width=20, + n_fourier_feats=64, + scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + activation="invalid", + random_state=42, + ) + model.fit(ts) From 04924721a37c2507c6b4537780dc0aae9f35545d Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 27 Oct 2022 10:04:07 +0200 Subject: [PATCH 02/18] updated the init file --- darts/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/darts/models/__init__.py b/darts/models/__init__.py index b9c96a096a..4a1623bbb4 100644 --- a/darts/models/__init__.py +++ b/darts/models/__init__.py @@ -29,13 +29,13 @@ try: from darts.models.forecasting.block_rnn_model import BlockRNNModel + from darts.models.forecasting.deeptime import DeepTimeModel from darts.models.forecasting.nbeats import NBEATSModel from darts.models.forecasting.nhits import NHiTSModel from darts.models.forecasting.rnn_model import RNNModel from darts.models.forecasting.tcn_model import TCNModel from darts.models.forecasting.tft_model import TFTModel from darts.models.forecasting.transformer_model import TransformerModel - except ModuleNotFoundError: logger.warning( "Support for Torch based models not available. 
" From 4f24cb6290abb6d4975fcffc5787d8665098407a Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 27 Oct 2022 12:48:38 +0200 Subject: [PATCH 03/18] attempt to implement the optimizer described in the original article --- darts/models/forecasting/deeptime.py | 156 ++++++++++++++++-- .../tests/models/forecasting/test_deeptime.py | 8 + 2 files changed, 150 insertions(+), 14 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index 41251c945a..9f8956dc13 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -206,6 +206,7 @@ def __init__( inr_layers_width: Union[int, List[int]], n_fourier_feats: int, scales: list[float], + legacy_optimiser: bool, dropout: float, activation: str, **kwargs, @@ -228,6 +229,8 @@ def __init__( Number of Fourier components to sample to represent to time-serie in the frequency domain scales Scaling factors applied to the normal distribution sampled for Fourier components' magnitude + legacy_optimiser + Determine if the optimiser policy defined in the original article should be used dropout The dropout probability to be used in fully connected layers (default=0). This is compatible with Monte Carlo dropout at inference time for model uncertainty estimation (enabled with @@ -248,11 +251,11 @@ def __init__( self.inr_layers_width = inr_layers_width self.n_fourier_feats = n_fourier_feats self.scales = scales + self.legacy_optimiser = legacy_optimiser self.dropout = dropout self.activation = activation - # single INR for both univariate and multivariate self.inr = INR( input_size=self.input_chunk_length, num_layers=self.inr_num_layers, @@ -263,9 +266,24 @@ def __init__( activation=self.activation, ) - # lambda_ constructor argument should be connected to torch lightning self.adaptive_weights = RidgeRegressor() + if self.legacy_optimiser: + # define three parameters groups: RidgeRegressor, no decay (bias and norm), decay + ( + optimiser_cls, + optimiser_kwg, + scheduler_cls, + scheduler_kwg, + ) = self.params_legacy_optimizers() + + kwargs["optimizer_cls"] = optimiser_cls + kwargs["optimizer_kwargs"] = optimiser_kwg + kwargs["scheduler_cls"] = optimiser_cls + kwargs["scheduler_kwargs"] = optimiser_kwg + + # TODO: configure_optimiser should be called again... + def forward(self, x_in: Tensor) -> Tensor: x, _ = x_in # x_in: (past_target|past_covariate, static_covariates) @@ -274,15 +292,6 @@ def forward(self, x_in: Tensor) -> Tensor: x.device ) - """ - I don't understand why this code is there... Why would the time have last dim greater than 0?! - if y_time.shape[-1] != 0: - time = torch.cat([x_time, y_time], dim=1) - coords = repeat(coords, '1 t 1 -> b t 1', b=time.shape[0]) - coords = torch.cat([coords, time], dim=-1) - time_reprs = self.inr(coords) - """ - # multivariate if self.output_dim > 1: time = torch.zeros( @@ -318,6 +327,104 @@ def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor: # would it be clearer with the unsqueeze operator? 
return torch.reshape(coords, (1, lookback_len + horizon_len, 1)) + def params_legacy_optimizers(self): + """Generate arguments needed to configure the orignal article optimizer""" + # we have to create copies because we cannot save model.parameters into object state (not serializable) + optimizer_kws = {k: v for k, v in self.optimizer_kwargs.items()} + + group1 = [] # lambda + group2 = [] # no decay + group3 = [] # decay + no_decay_list = ( + "bias", + "norm", + ) + for param_name, param in self.named_parameters(): + if "_lambda" in param_name: + group1.append(param) + elif any([mod in param_name for mod in no_decay_list]): + group2.append(param) + else: + group3.append(param) + + # parameters for the optimiser + optimiser_class = torch.optim.Adam + optimiser_params = [ + { + "params": group1, + "weight_decay": optimizer_kws["weight_decay"], + "lr": optimizer_kws["lambda_lr"], + }, + {"params": group2, "weight_decay": optimizer_kws["weight_decay"]}, + {"params": group3}, + ] + + # parameters for the scheduler + warmup_epochs = self.scheduler_kwargs["warmup_epochs"] + T_max = self.scheduler_kwargs["Tmax"] + eta_min = 0.0 + scheduler_fns = [] + for param_group, scheduler in zip( + [group1, group2, group3], + [ + "cosine_annealing", + "cosine_annealing_with_linear_warmup", + "cosine_annealing_with_linear_warmup", + ], + ): + lr = eta_max = optimizer_kws["lambda_lr"] + if scheduler == "cosine_annealing": + fn = ( + lambda T_cur: ( + eta_min + + 0.5 + * (eta_max - eta_min) + * ( + 1.0 + + np.cos( + (T_cur - warmup_epochs) + / (T_max - warmup_epochs) + * np.pi + ) + ) + ) + / lr + ) + elif scheduler == "cosine_annealing_with_linear_warmup": + fn = ( + lambda T_cur: T_cur / warmup_epochs + if T_cur < warmup_epochs + else ( + eta_min + + 0.5 + * (eta_max - eta_min) + * ( + 1.0 + + np.cos( + (T_cur - warmup_epochs) + / (T_max - warmup_epochs) + * np.pi + ) + ) + ) + / lr + ) + + scheduler_fns.append(fn) + scheduler_class = torch.optim.lr_scheduler.lambda_lr + scheduler_params = {"lr_lambda": scheduler_fns} + + return optimiser_class, optimiser_params, scheduler_class, scheduler_params + + # TODO: must be implemented since inheriting from PLPastCovariatesModule + """ + def _produce_train_output(): + raise NotImplementedError() + + def _get_batch_prediction(): + raise NotImplementedError() + """ + class DeepTimeModel(PastCovariatesTorchModel): def __init__( @@ -328,6 +435,7 @@ def __init__( inr_layers_width: Union[int, List[int]] = 256, n_fourier_feats: int = 4096, scales: list[float] = [0.01, 0.1, 1, 5, 10, 20, 50, 100], + legacy_optimiser: bool = True, dropout: float = 0.1, activation: str = "ReLU", **kwargs, @@ -355,6 +463,8 @@ def __init__( Number of Fourier components to sample to represent to time-serie in the frequency domain scales Scaling factors applied to the normal distribution sampled for Fourier components' magnitude + legacy_optimiser + Determine if the optimiser policy defined in the original article should be used (default=True). dropout The dropout probability to be used in fully connected layers (default=0). 
This is compatible with Monte Carlo dropout at inference time for model uncertainty estimation (enabled with @@ -459,11 +569,26 @@ def __init__( f"n_fourier_feats: {n_fourier_feats} must be divisible by 2 * len(scales) = {2 * len(scales)}", logger, ) + raise_if_not( + not legacy_optimiser + or ( + "weight_decay" in self.pl_module_params["optimizer_kwargs"] + and "lambda_lr" in self.pl_module_params["optimizer_kwargs"] + and "lr" in self.pl_module_params["optimizer_kwargs"] + and "Tmax" in self.pl_module_params["scheduler_kwargs"] + and "warmup_epochs" in self.pl_module_params["scheduler_kwargs"] + ), + "At least one argument is missing to use the legacy optimiser. `weight_decay`, " + "`lambda_lr` and `lr` must be defined in `optimizer_kwargs` whereas `Tmax` and " + "'warmup_epochs must be defined in `scheduler_kwargs`.", + logger, + ) self.inr_num_layers = inr_num_layers self.inr_layers_width = inr_layers_width self.n_fourier_feats = n_fourier_feats self.scales = scales + self.legacy_optimiser = legacy_optimiser self.dropout = dropout self.activation = activation @@ -473,7 +598,10 @@ def __init__( def _supports_static_covariates() -> bool: return False - def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: + def _create_model( + self, + train_sample: Tuple[torch.Tensor], + ) -> torch.nn.Module: input_dim = train_sample[0].shape[1] + ( train_sample[1].shape[1] if train_sample[1] is not None else 0 ) @@ -485,13 +613,13 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: input_dim=input_dim, output_dim=output_dim, nr_params=nr_params, - input_chunk_length=self.input_chunk_length, - output_chunk_length=self.output_chunk_length, inr_num_layers=self.inr_num_layers, inr_layers_width=self.inr_layers_width, n_fourier_feats=self.n_fourier_feats, scales=self.scales, + legacy_optimiser=self.legacy_optimiser, dropout=self.dropout, activation=self.activation, + **self.pl_module_params, ) return model diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py index 17016f1c3e..4861ecbeaa 100644 --- a/darts/tests/models/forecasting/test_deeptime.py +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -50,6 +50,7 @@ def test_fit(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + legacy_optimiser=False, random_state=42, ) model.fit(large_ts[:98]) @@ -64,6 +65,7 @@ def test_fit(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + legacy_optimiser=False, random_state=42, ) model2.fit(small_ts[:98]) @@ -88,6 +90,7 @@ def test_multivariate(self): inr_layers_width=32, n_fourier_feats=256, # must have enough fourier components to capture low freq scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + legacy_optimiser=False, random_state=42, ) model.fit(series_multivariate) @@ -111,6 +114,7 @@ def test_multivariate(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + legacy_optimiser=False, random_state=42, ) model.fit(series_multivariate, past_covariates=series_covariates) @@ -134,6 +138,7 @@ def test_deeptime_n_fourier_feats(self): inr_layers_width=20, n_fourier_feats=17, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + legacy_optimiser=False, random_state=42, ) @@ -149,6 +154,7 @@ def test_logtensorboard(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + legacy_optimiser=False, random_state=42, ) model.fit(ts) @@ -165,6 +171,7 @@ def 
test_activation_fns(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + legacy_optimiser=False, activation="LeakyReLU", random_state=42, ) @@ -179,6 +186,7 @@ def test_activation_fns(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + legacy_optimiser=False, activation="invalid", random_state=42, ) From 60353366f121887a6ce25241644555e62924e6d3 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 27 Oct 2022 15:28:11 +0200 Subject: [PATCH 04/18] changed the approach for optimizer definition, now overriding the configure_optimizer method --- darts/models/forecasting/deeptime.py | 215 ++++++++++-------- .../tests/models/forecasting/test_deeptime.py | 18 +- 2 files changed, 129 insertions(+), 104 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index 9f8956dc13..f3811a3c3d 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -268,22 +268,6 @@ def __init__( self.adaptive_weights = RidgeRegressor() - if self.legacy_optimiser: - # define three parameters groups: RidgeRegressor, no decay (bias and norm), decay - ( - optimiser_cls, - optimiser_kwg, - scheduler_cls, - scheduler_kwg, - ) = self.params_legacy_optimizers() - - kwargs["optimizer_cls"] = optimiser_cls - kwargs["optimizer_kwargs"] = optimiser_kwg - kwargs["scheduler_cls"] = optimiser_cls - kwargs["scheduler_kwargs"] = optimiser_kwg - - # TODO: configure_optimiser should be called again... - def forward(self, x_in: Tensor) -> Tensor: x, _ = x_in # x_in: (past_target|past_covariate, static_covariates) @@ -327,13 +311,13 @@ def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor: # would it be clearer with the unsqueeze operator? 
return torch.reshape(coords, (1, lookback_len + horizon_len, 1)) - def params_legacy_optimizers(self): - """Generate arguments needed to configure the orignal article optimizer""" + def configure_optimizers(self): # we have to create copies because we cannot save model.parameters into object state (not serializable) optimizer_kws = {k: v for k, v in self.optimizer_kwargs.items()} - group1 = [] # lambda - group2 = [] # no decay + # define three parameters groups + group1 = [] # lambda (RidgeRegressor) + group2 = [] # no decay (bias and norm) group3 = [] # decay no_decay_list = ( "bias", @@ -346,75 +330,77 @@ def params_legacy_optimizers(self): group2.append(param) else: group3.append(param) - - # parameters for the optimiser - optimiser_class = torch.optim.Adam - optimiser_params = [ - { - "params": group1, - "weight_decay": optimizer_kws["weight_decay"], - "lr": optimizer_kws["lambda_lr"], - }, - {"params": group2, "weight_decay": optimizer_kws["weight_decay"]}, - {"params": group3}, - ] - - # parameters for the scheduler - warmup_epochs = self.scheduler_kwargs["warmup_epochs"] - T_max = self.scheduler_kwargs["Tmax"] - eta_min = 0.0 - scheduler_fns = [] - for param_group, scheduler in zip( - [group1, group2, group3], + optimizer = torch.optim.Adam( [ - "cosine_annealing", - "cosine_annealing_with_linear_warmup", - "cosine_annealing_with_linear_warmup", + {"params": group1, "weight_decay": 0, "lr": optimizer_kws["lambda_lr"]}, + {"params": group2, "weight_decay": 0}, + {"params": group3}, ], - ): - lr = eta_max = optimizer_kws["lambda_lr"] - if scheduler == "cosine_annealing": - fn = ( - lambda T_cur: ( - eta_min - + 0.5 - * (eta_max - eta_min) - * ( - 1.0 - + np.cos( - (T_cur - warmup_epochs) - / (T_max - warmup_epochs) - * np.pi - ) - ) - ) - / lr + lr=optimizer_kws["lr"], + weight_decay=optimizer_kws["weight_decay"], + ) + + # define a scheduler for each optimizer + lr_sched_kws = {k: v for k, v in self.lr_scheduler_kwargs.items()} + + warmup_epochs = lr_sched_kws["warmup_epochs"] + T_max = lr_sched_kws["T_max"] + eta_min = lr_sched_kws["eta_min"] + scheduler_fns = [] + + def no_scheduler(T_cur): + return 1 + + def cosine_annealing(T_cur): + return ( + eta_min + + 0.5 + * (eta_max - eta_min) + * ( + 1.0 + + np.cos((T_cur - warmup_epochs) / (T_max - warmup_epochs) * np.pi) ) - elif scheduler == "cosine_annealing_with_linear_warmup": - fn = ( - lambda T_cur: T_cur / warmup_epochs - if T_cur < warmup_epochs - else ( - eta_min - + 0.5 - * (eta_max - eta_min) - * ( - 1.0 - + np.cos( - (T_cur - warmup_epochs) - / (T_max - warmup_epochs) - * np.pi - ) + ) / lr + + def cosine_annealing_with_linear_warmup(T_cur): + if T_cur < warmup_epochs: + return T_cur / warmup_epochs + else: + return ( + eta_min + + 0.5 + * (eta_max - eta_min) + * ( + 1.0 + + np.cos( + (T_cur - warmup_epochs) / (T_max - warmup_epochs) * np.pi ) ) - / lr - ) + ) / lr + for param_group, scheduler in zip( + optimizer.param_groups, lr_sched_kws["scheduler_names"] + ): + if scheduler == "none": + fn = no_scheduler + elif scheduler == "cosine_annealing": + lr = eta_max = param_group["lr"] + fn = cosine_annealing + elif scheduler == "cosine_annealing_with_linear_warmup": + lr = eta_max = param_group["lr"] + fn = cosine_annealing_with_linear_warmup + else: + raise ValueError(f"No such scheduler, {scheduler}") scheduler_fns.append(fn) - scheduler_class = torch.optim.lr_scheduler.lambda_lr - scheduler_params = {"lr_lambda": scheduler_fns} - return optimiser_class, optimiser_params, scheduler_class, scheduler_params + lr_scheduler = 
torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=scheduler_fns + ) + + return [optimizer], { + "scheduler": lr_scheduler, + "monitor": "val_loss", + } # TODO: must be implemented since inheriting from PLPastCovariatesModule """ @@ -569,20 +555,59 @@ def __init__( f"n_fourier_feats: {n_fourier_feats} must be divisible by 2 * len(scales) = {2 * len(scales)}", logger, ) - raise_if_not( - not legacy_optimiser - or ( - "weight_decay" in self.pl_module_params["optimizer_kwargs"] - and "lambda_lr" in self.pl_module_params["optimizer_kwargs"] - and "lr" in self.pl_module_params["optimizer_kwargs"] - and "Tmax" in self.pl_module_params["scheduler_kwargs"] - and "warmup_epochs" in self.pl_module_params["scheduler_kwargs"] - ), - "At least one argument is missing to use the legacy optimiser. `weight_decay`, " - "`lambda_lr` and `lr` must be defined in `optimizer_kwargs` whereas `Tmax` and " - "'warmup_epochs must be defined in `scheduler_kwargs`.", - logger, - ) + + if legacy_optimiser: + self.pl_module_params["optimizer_kwargs"] = { + "lr": 1e-3, + "lambda_lr": 1.0, + "weight_decay": 0.0, + } + + self.pl_module_params["lr_scheduler_kwargs"] = { + "warmup_epochs": 5, + "T_max": self.n_epochs, + "eta_min": 0.0, + "scheduler_names": [ + "cosine_annealing", + "cosine_annealing_with_linear_warmup", + "cosine_annealing_with_linear_warmup", + ], + } + else: + raise_if_not( + "optimizer_kwargs" in self.pl_module_params.keys(), + "`optimizer_kwargs` should contain the following arguments: `weight_decay`, " + "`lambda_lr` and `lr` in order to create the optimizer.", + logger, + ) + + raise_if_not( + "lr_scheduler_kwargs" in self.pl_module_params.keys(), + "`lr_scheduler_kwargs` should contain the following arguments: `Tmax` and " + "`warmup_epochs` in order to create the optimizer.", + logger, + ) + + expected_params = { + "weight_decay", + "lambda_lr", + "lr", + "T_max", + "warmup_epochs", + "eta_min", + "scheduler_names", + } + optimizer_params = self.pl_module_params["optimizer_kwargs"].keys() + scheduler_params = self.pl_module_params["lr_scheduler_kwargs"].keys() + provided_params = set(optimizer_params).union(set(scheduler_params)) + missing_params = expected_params - provided_params + raise_if_not( + not legacy_optimiser and len(missing_params) == 0, + f"Missing argument(s) for the optimiser: {missing_params}. 
`weight_decay`, " + "`lambda_lr` and `lr` must be defined in `optimizer_kwargs` whereas `Tmax`, " + "`scheduler_names` and `warmup_epochs` must be defined in `scheduler_kwargs`.", + logger, + ) self.inr_num_layers = inr_num_layers self.inr_layers_width = inr_layers_width diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py index 4861ecbeaa..689bc8cfaa 100644 --- a/darts/tests/models/forecasting/test_deeptime.py +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -50,7 +50,7 @@ def test_fit(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=False, + legacy_optimiser=True, random_state=42, ) model.fit(large_ts[:98]) @@ -65,7 +65,7 @@ def test_fit(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=False, + legacy_optimiser=True, random_state=42, ) model2.fit(small_ts[:98]) @@ -87,10 +87,10 @@ def test_multivariate(self): output_chunk_length=1, n_epochs=10, inr_num_layers=5, - inr_layers_width=32, + inr_layers_width=64, n_fourier_feats=256, # must have enough fourier components to capture low freq scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=False, + legacy_optimiser=True, random_state=42, ) model.fit(series_multivariate) @@ -114,7 +114,7 @@ def test_multivariate(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=False, + legacy_optimiser=True, random_state=42, ) model.fit(series_multivariate, past_covariates=series_covariates) @@ -138,7 +138,7 @@ def test_deeptime_n_fourier_feats(self): inr_layers_width=20, n_fourier_feats=17, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=False, + legacy_optimiser=True, random_state=42, ) @@ -154,7 +154,7 @@ def test_logtensorboard(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=False, + legacy_optimiser=True, random_state=42, ) model.fit(ts) @@ -171,7 +171,7 @@ def test_activation_fns(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=False, + legacy_optimiser=True, activation="LeakyReLU", random_state=42, ) @@ -186,7 +186,7 @@ def test_activation_fns(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=False, + legacy_optimiser=True, activation="invalid", random_state=42, ) From 3db5d1474252efaa7e23a38391af0aaa36843ead Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 27 Oct 2022 15:31:08 +0200 Subject: [PATCH 05/18] corrected small typos in the documentation --- darts/dataprocessing/transformers/scaler.py | 2 +- darts/utils/data/training_dataset.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/darts/dataprocessing/transformers/scaler.py b/darts/dataprocessing/transformers/scaler.py index 613c71662b..502e9523a1 100644 --- a/darts/dataprocessing/transformers/scaler.py +++ b/darts/dataprocessing/transformers/scaler.py @@ -73,7 +73,7 @@ def __init__( >>> print(min(series_transformed.values())) [-1.] >>> print(max(series_transformed.values())) - [2.] + [1.] 
""" super().__init__(name=name, n_jobs=n_jobs, verbose=verbose) diff --git a/darts/utils/data/training_dataset.py b/darts/utils/data/training_dataset.py index d485ee6159..8147d7f3af 100644 --- a/darts/utils/data/training_dataset.py +++ b/darts/utils/data/training_dataset.py @@ -202,7 +202,7 @@ def _memory_indexer( class PastCovariatesTrainingDataset(TrainingDataset, ABC): def __init__(self): """ - Abstract class for a PastCovariatesTorchModel training dataset. It contains 3-tuples of + Abstract class for a PastCovariatesTorchModel training dataset. It contains 4-tuples of `(past_target, past_covariate, static_covariates, future_target)` `np.ndarray`. The covariates are optional and can be `None`. """ @@ -218,7 +218,7 @@ def __getitem__( class FutureCovariatesTrainingDataset(TrainingDataset, ABC): def __init__(self): """ - Abstract class for a FutureCovariatesTorchModel training dataset. It contains 3-tuples of + Abstract class for a FutureCovariatesTorchModel training dataset. It contains 4-tuples of `(past_target, future_covariate, static_covariates, future_target)` `np.ndarray`. The covariates are optional and can be `None`. """ @@ -234,7 +234,7 @@ def __getitem__( class DualCovariatesTrainingDataset(TrainingDataset, ABC): def __init__(self): """ - Abstract class for a DualCovariatesTorchModel training dataset. It contains 4-tuples of + Abstract class for a DualCovariatesTorchModel training dataset. It contains 5-tuples of `(past_target, historic_future_covariates, future_covariates, static_covariates, future_target)` `np.ndarray`. The covariates are optional and can be `None`. """ @@ -256,7 +256,7 @@ def __getitem__( class MixedCovariatesTrainingDataset(TrainingDataset, ABC): def __init__(self): """ - Abstract class for a MixedCovariatesTorchModel training dataset. It contains 5-tuples of + Abstract class for a MixedCovariatesTorchModel training dataset. It contains 6-tuples of `(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates, future_target)` `np.ndarray`. The covariates are optional and can be `None`. @@ -280,7 +280,7 @@ def __getitem__( class SplitCovariatesTrainingDataset(TrainingDataset, ABC): def __init__(self): """ - Abstract class for a SplitCovariatesTorchModel training dataset. It contains 4-tuples of + Abstract class for a SplitCovariatesTorchModel training dataset. It contains 5-tuples of `(past_target, past_covariates, future_covariates, static_covariates, future_target)` `np.ndarray`. The covariates are optional and can be `None`. 
""" From d52f231a70d9be33e79358307d2f153e89031fab Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 27 Oct 2022 15:45:22 +0200 Subject: [PATCH 06/18] improved variable naming for the schedulers --- darts/models/forecasting/deeptime.py | 35 ++++++++++++++++++---------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index f3811a3c3d..a7a6846462 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -343,28 +343,32 @@ def configure_optimizers(self): # define a scheduler for each optimizer lr_sched_kws = {k: v for k, v in self.lr_scheduler_kwargs.items()} + total_epochs = lr_sched_kws["total_epochs"] warmup_epochs = lr_sched_kws["warmup_epochs"] - T_max = lr_sched_kws["T_max"] eta_min = lr_sched_kws["eta_min"] scheduler_fns = [] - def no_scheduler(T_cur): + def no_scheduler(current_epoch): return 1 - def cosine_annealing(T_cur): + def cosine_annealing(current_epoch): return ( eta_min + 0.5 * (eta_max - eta_min) * ( 1.0 - + np.cos((T_cur - warmup_epochs) / (T_max - warmup_epochs) * np.pi) + + np.cos( + (current_epoch - warmup_epochs) + / (total_epochs - warmup_epochs) + * np.pi + ) ) ) / lr - def cosine_annealing_with_linear_warmup(T_cur): - if T_cur < warmup_epochs: - return T_cur / warmup_epochs + def cosine_annealing_with_linear_warmup(current_epoch): + if current_epoch < warmup_epochs: + return current_epoch / warmup_epochs else: return ( eta_min @@ -373,7 +377,9 @@ def cosine_annealing_with_linear_warmup(T_cur): * ( 1.0 + np.cos( - (T_cur - warmup_epochs) / (T_max - warmup_epochs) * np.pi + (current_epoch - warmup_epochs) + / (total_epochs - warmup_epochs) + * np.pi ) ) ) / lr @@ -563,9 +569,15 @@ def __init__( "weight_decay": 0.0, } + raise_if_not( + self.n_epochs > 0, + "`self.n_epochs` should be greater than 0, it is used to declare the learning rate scheduler.", + logger, + ) + self.pl_module_params["lr_scheduler_kwargs"] = { "warmup_epochs": 5, - "T_max": self.n_epochs, + "total_epochs": self.n_epochs, "eta_min": 0.0, "scheduler_names": [ "cosine_annealing", @@ -583,8 +595,8 @@ def __init__( raise_if_not( "lr_scheduler_kwargs" in self.pl_module_params.keys(), - "`lr_scheduler_kwargs` should contain the following arguments: `Tmax` and " - "`warmup_epochs` in order to create the optimizer.", + "`lr_scheduler_kwargs` should contain the following arguments: `eta_min`, " + "`warmup_epochs` and `scheduler_names` in order to create the optimizer.", logger, ) @@ -592,7 +604,6 @@ def __init__( "weight_decay", "lambda_lr", "lr", - "T_max", "warmup_epochs", "eta_min", "scheduler_names", From e13ed578d371a399f4d577768c69b4d81e4330e4 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 27 Oct 2022 19:18:21 +0200 Subject: [PATCH 07/18] simplified the model implementation based on the article pseudo code and taking advantage of the safeguards offered by darts --- darts/models/forecasting/deeptime.py | 43 +++++-------------- .../tests/models/forecasting/test_deeptime.py | 15 ++++--- 2 files changed, 19 insertions(+), 39 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index a7a6846462..90a8c71351 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -57,12 +57,6 @@ def forward(self, x: Tensor) -> Tensor: f"Expected 2 or more dimensional input (got {x.dim()}D input)", logger, ) - dim = x.shape[-1] - raise_if_not( - dim == self.input_dim, - f"Expected input to have 
{self.input_dim} channels (got {dim} channels)", - logger, - ) x = torch.einsum("... t n, n d -> ... t d", [x, self.B]) x = 2 * np.pi * x @@ -72,7 +66,7 @@ def forward(self, x: Tensor) -> Tensor: class INR(nn.Module): def __init__( self, - input_size: int, + input_dim: int, num_layers: int, hidden_layers_width: int, n_fourier_feats: int, @@ -86,7 +80,7 @@ def __init__( Parameters ---------- - input_size + input_dim The dimensionality of the input time series. num_layers The number of fully connected layers. @@ -115,7 +109,7 @@ def __init__( """ super().__init__() - self.input_size = input_size + self.input_dim = input_dim self.num_layers = num_layers self.n_fourier_feats = n_fourier_feats self.scales = scales @@ -133,11 +127,11 @@ def __init__( if n_fourier_feats == 0: feats_size = self.hidden_layers_width[0] - self.features = nn.Linear(self.input_size, feats_size) + self.features = nn.Linear(self.input_dim, feats_size) else: feats_size = self.n_fourier_feats self.features = GaussianFourierFeatureTransform( - self.input_size, feats_size, self.scales + self.input_dim, feats_size, self.scales ) # Fully Connected Network @@ -257,7 +251,7 @@ def __init__( self.activation = activation self.inr = INR( - input_size=self.input_chunk_length, + input_dim=self.input_dim + 1, num_layers=self.inr_num_layers, hidden_layers_width=self.inr_layers_width, n_fourier_feats=self.n_fourier_feats, @@ -271,27 +265,11 @@ def __init__( def forward(self, x_in: Tensor) -> Tensor: x, _ = x_in # x_in: (past_target|past_covariate, static_covariates) - batch_size, _, in_dim = x.shape # x: (batch, in_len, in_dim) - coords = self.get_coords(self.input_chunk_length, self.output_chunk_length).to( - x.device - ) + batch_size, _, _ = x.shape # x: (batch_size, in_len, in_dim) + coords = self.get_coords(self.input_chunk_length, self.output_chunk_length) - # multivariate - if self.output_dim > 1: - time = torch.zeros( - size=( - batch_size, - self.input_chunk_length + self.output_chunk_length, - self.output_dim, - ) - ).to(x.device) - coords = coords.repeat(batch_size, 1, 1) - coords = torch.cat([coords, time], dim=-1) - time_reprs = self.inr(coords) - # univariate - else: - time_reprs = self.inr(coords) - time_reprs = time_reprs.repeat(batch_size, 1, 1) + time_reprs = self.inr(coords) + time_reprs = time_reprs.repeat(batch_size, 1, self.input_dim) lookback_reprs = time_reprs[:, : -self.output_chunk_length] horizon_reprs = time_reprs[:, -self.output_chunk_length :] @@ -300,6 +278,7 @@ def forward(self, x_in: Tensor) -> Tensor: # apply weights to the horizon y = torch.einsum("... d o, ... t d -> ... 
t o", [w, horizon_reprs]) + b + # taken from nbeats, could be simplified y = y.view( y.shape[0], self.output_chunk_length, self.input_dim, self.nr_params )[:, :, : self.output_dim, :] diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py index 689bc8cfaa..e218a3a06b 100644 --- a/darts/tests/models/forecasting/test_deeptime.py +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -31,7 +31,7 @@ def test_creation(self): with self.assertRaises(ValueError): # if a list is passed to the `layer_widths` argument, it must have a length equal to `num_stacks` DeepTimeModel( - input_chunk_length=1, + input_chunk_length=2, output_chunk_length=1, inr_num_layers=3, inr_layers_width=[1], @@ -43,7 +43,7 @@ def test_fit(self): # Test basic fit and predict model = DeepTimeModel( - input_chunk_length=1, + input_chunk_length=2, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -58,7 +58,7 @@ def test_fit(self): # Test whether model trained on one series is better than one trained on another model2 = DeepTimeModel( - input_chunk_length=1, + input_chunk_length=2, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -98,6 +98,7 @@ def test_multivariate(self): # the theoretical result should be [[1.01, 1.02], [0.505, 0.51]]. # We just test if the given result is not too far on average. + print(res) self.assertTrue( abs(np.average(res - np.array([[1.01, 1.02], [0.505, 0.51]])) < 0.03) ) @@ -131,7 +132,7 @@ def test_deeptime_n_fourier_feats(self): # wrong number of scales and n_fourier feats # n_fourier_feats must be divisiable by 2*len(scales) DeepTimeModel( - input_chunk_length=1, + input_chunk_length=2, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -147,7 +148,7 @@ def test_logtensorboard(self): # Test basic fit and predict model = DeepTimeModel( - input_chunk_length=1, + input_chunk_length=2, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -164,7 +165,7 @@ def test_activation_fns(self): ts = tg.constant_timeseries(length=50, value=10) model = DeepTimeModel( - input_chunk_length=1, + input_chunk_length=2, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -179,7 +180,7 @@ def test_activation_fns(self): with self.assertRaises(ValueError): model = DeepTimeModel( - input_chunk_length=1, + input_chunk_length=2, output_chunk_length=1, n_epochs=10, inr_num_layers=2, From 00e7f18bd7942492c134020f9b13f6d30ed994cc Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 27 Oct 2022 19:21:53 +0200 Subject: [PATCH 08/18] reverted input_length_chunck to values identical to the nbeats tests --- darts/tests/models/forecasting/test_deeptime.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py index e218a3a06b..facaffd1b9 100644 --- a/darts/tests/models/forecasting/test_deeptime.py +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -31,7 +31,7 @@ def test_creation(self): with self.assertRaises(ValueError): # if a list is passed to the `layer_widths` argument, it must have a length equal to `num_stacks` DeepTimeModel( - input_chunk_length=2, + input_chunk_length=1, output_chunk_length=1, inr_num_layers=3, inr_layers_width=[1], @@ -43,7 +43,7 @@ def test_fit(self): # Test basic fit and predict model = DeepTimeModel( - input_chunk_length=2, + input_chunk_length=1, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -58,7 +58,7 @@ def test_fit(self): # Test whether model trained on one series is better 
than one trained on another model2 = DeepTimeModel( - input_chunk_length=2, + input_chunk_length=1, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -132,7 +132,7 @@ def test_deeptime_n_fourier_feats(self): # wrong number of scales and n_fourier feats # n_fourier_feats must be divisiable by 2*len(scales) DeepTimeModel( - input_chunk_length=2, + input_chunk_length=1, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -148,7 +148,7 @@ def test_logtensorboard(self): # Test basic fit and predict model = DeepTimeModel( - input_chunk_length=2, + input_chunk_length=1, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -165,7 +165,7 @@ def test_activation_fns(self): ts = tg.constant_timeseries(length=50, value=10) model = DeepTimeModel( - input_chunk_length=2, + input_chunk_length=1, output_chunk_length=1, n_epochs=10, inr_num_layers=2, @@ -180,7 +180,7 @@ def test_activation_fns(self): with self.assertRaises(ValueError): model = DeepTimeModel( - input_chunk_length=2, + input_chunk_length=1, output_chunk_length=1, n_epochs=10, inr_num_layers=2, From 99034353d42ab39b9b8e7992e4dfb06af00a520d Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 28 Oct 2022 09:57:08 +0200 Subject: [PATCH 09/18] improved the docstring --- darts/models/forecasting/deeptime.py | 93 ++++++++++++++++++---------- 1 file changed, 59 insertions(+), 34 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index 90a8c71351..38e26f9cd5 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -34,12 +34,31 @@ class GaussianFourierFeatureTransform(nn.Module): - """ - https://github.com/ndahlquist/pytorch-fourier-feature-networks - Given an input of size [..., time, dim], returns a tensor of size [..., n_fourier_feats, time]. - """ - def __init__(self, input_dim: int, n_fourier_feats: int, scales: List[float]): + """ + Implementation of the Gaussian Fourier features mapping. + https://arxiv.org/abs/2006.10739 + https://github.com/ndahlquist/pytorch-fourier-feature-networks + + Parameters + ---------- + input_dim + The dimensionality of the input time series. + n_fourier_feats + Number of Fourier components to sample to represent to time-serie in the frequency domain + scales + Scaling factors applied to the normal distribution sampled for Fourier components' magnitude + + Inputs + ------ + x of shape `(1, input_chunk_length+output_chunk_length, 1)` + Tensor containing the [0,1] normalised time representation. + + Outputs + ------- + y of shape `(1, input_chunk_length+output_chunk_length, n_fourier_feats)` + Tensor containing the Gaussian Fourier features for the processed period. + """ super().__init__() self.input_dim = input_dim self.n_fourier_feats = n_fourier_feats @@ -57,7 +76,6 @@ def forward(self, x: Tensor) -> Tensor: f"Expected 2 or more dimensional input (got {x.dim()}D input)", logger, ) - x = torch.einsum("... t n, n d -> ... t d", [x, self.B]) x = 2 * np.pi * x return torch.cat([torch.sin(x), torch.cos(x)], dim=-1) @@ -79,33 +97,33 @@ def __init__( Features can be encoded using either a Linear layer or a Gaussian Fourier Transform Parameters - ---------- - input_dim - The dimensionality of the input time series. - num_layers - The number of fully connected layers. - hidden_layers_width - Determines the number of neurons that make up each hidden fully connected layer. - If a list is passed, it must have a length equal to `num_layers`. If an integer is passed, - every layers will have the same width. 
- n_fourier_feats - Number of Fourier components to sample to represent to time-serie in the frequency domain - scales - Scaling factors applied to the normal distribution sampled for Fourier components' magnitude - dropout - The fraction of neurons that are dropped at each layer. - activation - The activation function of fully connected network intermediate layers. - - Inputs - ------ - x of shape `(batch_size, input_size)` - Tensor containing the features of the input sequence. - - Outputs - ------- - y of shape `(batch_size, output_chunk_length, target_size, nr_params)` - Tensor containing the prediction at the last time step of the sequence. + ---------- + input_dim + The dimensionality of the input time series. + num_layers + The number of fully connected layers. + hidden_layers_width + Determines the number of neurons that make up each hidden fully connected layer. + If a list is passed, it must have a length equal to `num_layers`. If an integer is passed, + every layer will have the same width. + n_fourier_feats + Number of Fourier components to sample to represent the time series in the frequency domain + scales + Scaling factors applied to the normal distribution sampled for Fourier components' magnitude + dropout + The fraction of neurons that are dropped at each layer. + activation + The activation function of fully connected network intermediate layers. + + Inputs + ------ + x of shape `(1, input_chunk_length+output_chunk_length, 1)` + Tensor containing the [0,1] normalised time representation. + + Outputs + ------- + y of shape `(1, input_chunk_length+output_chunk_length, hidden_layers_width)` + Tensor containing the implicit neural representation of the time. """ super().__init__() @@ -158,6 +176,7 @@ def forward(self, x: Tensor) -> Tensor: class RidgeRegressor(nn.Module): def __init__(self, lambda_init: float = 0.0): + """Implementation of the closed form Ridge Regression with a regularization coefficient.""" super().__init__() self._lambda = nn.Parameter(torch.as_tensor(lambda_init, dtype=torch.float)) @@ -291,6 +310,10 @@ def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor: return torch.reshape(coords, (1, lookback_len + horizon_len, 1)) def configure_optimizers(self): + """Override the configure_optimizers to define three groups of parameters, one for the + Ridge Regression weights, one for the bias and norm of the FCN and another for the other + weights of the FCN. + """ # we have to create copies because we cannot save model.parameters into object state (not serializable) optimizer_kws = {k: v for k, v in self.optimizer_kwargs.items()} @@ -435,7 +458,9 @@ def __init__( scales Scaling factors applied to the normal distribution sampled for Fourier components' magnitude legacy_optimiser - Determine if the optimiser policy defined in the original article should be used (default=True). + Determines whether the optimiser described in the original article should be used. Overwrites the + parameters provided in optimizer_cls, optimizer_kwargs, lr_scheduler_cls and lr_scheduler_kwargs. + Defaults to True. dropout The dropout probability to be used in fully connected layers (default=0). 
This is compatible with Monte Carlo dropout at inference time for model uncertainty estimation (enabled with From 470bb0881be55567700dd89e62c64396ed499043 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Mon, 31 Oct 2022 16:53:03 +0100 Subject: [PATCH 10/18] simplified logic around default arguments for optimizers and lr_scheduler --- darts/models/forecasting/deeptime.py | 138 +++++++++--------- .../tests/models/forecasting/test_deeptime.py | 8 - 2 files changed, 67 insertions(+), 79 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index 38e26f9cd5..f1077da99a 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -3,7 +3,7 @@ ------- """ -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -88,7 +88,7 @@ def __init__( num_layers: int, hidden_layers_width: int, n_fourier_feats: int, - scales: list[float], + scales: List[float], dropout: float, activation: str, ): @@ -218,8 +218,7 @@ def __init__( inr_num_layers: int, inr_layers_width: Union[int, List[int]], n_fourier_feats: int, - scales: list[float], - legacy_optimiser: bool, + scales: List[float], dropout: float, activation: str, **kwargs, @@ -242,8 +241,6 @@ def __init__( Number of Fourier components to sample to represent to time-serie in the frequency domain scales Scaling factors applied to the normal distribution sampled for Fourier components' magnitude - legacy_optimiser - Determine if the optimiser policy defined in the original article should be used dropout The dropout probability to be used in fully connected layers (default=0). This is compatible with Monte Carlo dropout at inference time for model uncertainty estimation (enabled with @@ -264,7 +261,6 @@ def __init__( self.inr_layers_width = inr_layers_width self.n_fourier_feats = n_fourier_feats self.scales = scales - self.legacy_optimiser = legacy_optimiser self.dropout = dropout self.activation = activation @@ -428,10 +424,23 @@ def __init__( inr_num_layers: int = 5, inr_layers_width: Union[int, List[int]] = 256, n_fourier_feats: int = 4096, - scales: list[float] = [0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser: bool = True, + scales: List[float] = [0.01, 0.1, 1, 5, 10, 20, 50, 100], dropout: float = 0.1, activation: str = "ReLU", + optimizer_kwargs: Optional[Dict] = { + "lr": 1e-3, + "lambda_lr": 1.0, + "weight_decay": 0.0, + }, + lr_scheduler_kwargs: Optional[Dict] = { + "warmup_epochs": 5, + "eta_min": 0.0, + "scheduler_names": [ + "cosine_annealing", + "cosine_annealing_with_linear_warmup", + "cosine_annealing_with_linear_warmup", + ], + }, **kwargs, ): """Deep time-index model with meta-learning (DeepTIMe). @@ -482,14 +491,19 @@ def __init__( optimizer_cls The PyTorch optimizer class to be used. Default: ``torch.optim.Adam``. optimizer_kwargs - Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}`` - for specifying a learning rate). Otherwise the default values of the selected ``optimizer_cls`` - will be used. Default: ``None``. + Optional keyword arguments for the PyTorch optimizer: ``'lr'`` for the INR networks weights and + ``'lambda_lr'`` for the Ridge Regression regularisation term. Otherwise the values from the + original publication will be used. Default: ``{"lr": 1e-3, "lambda_lr": 1.0, "weight_decay": 0.0}``. lr_scheduler_cls - Optionally, the PyTorch learning rate scheduler class to be used. 
Specifying ``None`` corresponds - to using a constant learning rate. Default: ``None``. + Due to the model architecture, distincts learning rate schedulers can be used for the three groups of + parameters: Ridge Regression regularisation term, (biais and norm) and weights of the INR network. + They must be provided using the lr_scheduler_kwargs argument. lr_scheduler_kwargs - Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + Optionally, names and keyword arguments for the three learning rate scheduler (respectively Ridge Regression + regularisation term, INR biais and norm, and INR weights. Supported scheduler: "none", "cosine_annealing" + and "cosine_annealing_with_linear_warmup". Default: {"warmup_epochs": 5, "total_epochs": self.n_epochs, + "eta_min": 0.0, "scheduler_names": ["cosine_annealing", "cosine_annealing_with_linear_warmup", + "cosine_annealing_with_linear_warmup"]}. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -566,69 +580,52 @@ def __init__( logger, ) - if legacy_optimiser: - self.pl_module_params["optimizer_kwargs"] = { - "lr": 1e-3, - "lambda_lr": 1.0, - "weight_decay": 0.0, - } - - raise_if_not( - self.n_epochs > 0, - "`self.n_epochs` should be greater than 0, it is used to declare the learning rate scheduler.", - logger, - ) + raise_if_not( + self.n_epochs > 0, + "`self.n_epochs` should be greater than 0, it is used to declare the learning rate schedulers.", + logger, + ) - self.pl_module_params["lr_scheduler_kwargs"] = { - "warmup_epochs": 5, - "total_epochs": self.n_epochs, - "eta_min": 0.0, - "scheduler_names": [ - "cosine_annealing", - "cosine_annealing_with_linear_warmup", - "cosine_annealing_with_linear_warmup", - ], - } - else: - raise_if_not( - "optimizer_kwargs" in self.pl_module_params.keys(), - "`optimizer_kwargs` should contain the following arguments: `weight_decay`, " - "`lambda_lr` and `lr` in order to create the optimizer.", - logger, - ) + raise_if_not( + "optimizer_kwargs" in self.pl_module_params.keys(), + "`optimizer_kwargs` should contain the following arguments: `weight_decay`, " + "`lambda_lr` and `lr` in order to create the optimizer.", + logger, + ) - raise_if_not( - "lr_scheduler_kwargs" in self.pl_module_params.keys(), - "`lr_scheduler_kwargs` should contain the following arguments: `eta_min`, " - "`warmup_epochs` and `scheduler_names` in order to create the optimizer.", - logger, - ) + raise_if_not( + "lr_scheduler_kwargs" in self.pl_module_params.keys(), + "`lr_scheduler_kwargs` should contain the following arguments: `eta_min`, " + "`warmup_epochs` and `scheduler_names` in order to create the optimizer.", + logger, + ) - expected_params = { - "weight_decay", - "lambda_lr", - "lr", - "warmup_epochs", - "eta_min", - "scheduler_names", - } - optimizer_params = self.pl_module_params["optimizer_kwargs"].keys() - scheduler_params = self.pl_module_params["lr_scheduler_kwargs"].keys() - provided_params = set(optimizer_params).union(set(scheduler_params)) - missing_params = expected_params - provided_params - raise_if_not( - not legacy_optimiser and len(missing_params) == 0, - f"Missing argument(s) for the optimiser: {missing_params}. 
`weight_decay`, " - "`lambda_lr` and `lr` must be defined in `optimizer_kwargs` whereas `Tmax`, " - "`scheduler_names` and `warmup_epochs` must be defined in `scheduler_kwargs`.", - logger, - ) + expected_params = { + "weight_decay", + "lambda_lr", + "lr", + "warmup_epochs", + "eta_min", + "scheduler_names", + } + optimizer_params = self.pl_module_params["optimizer_kwargs"].keys() + scheduler_params = self.pl_module_params["lr_scheduler_kwargs"].keys() + provided_params = set(optimizer_params).union(set(scheduler_params)) + missing_params = expected_params - provided_params + raise_if_not( + len(missing_params) == 0, + f"Missing argument(s) for the optimiser: {missing_params}. `weight_decay`, " + "`lambda_lr` and `lr` must be defined in `optimizer_kwargs` whereas `Tmax`, " + "`scheduler_names` and `warmup_epochs` must be defined in `lr_scheduler_kwargs`.", + logger, + ) + + self.pl_module_params["lr_scheduler_kwargs"]["total_epochs"] = self.n_epochs self.inr_num_layers = inr_num_layers self.inr_layers_width = inr_layers_width self.n_fourier_feats = n_fourier_feats self.scales = scales - self.legacy_optimiser = legacy_optimiser self.dropout = dropout self.activation = activation @@ -657,7 +654,6 @@ def _create_model( inr_layers_width=self.inr_layers_width, n_fourier_feats=self.n_fourier_feats, scales=self.scales, - legacy_optimiser=self.legacy_optimiser, dropout=self.dropout, activation=self.activation, **self.pl_module_params, diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py index facaffd1b9..17f901eebd 100644 --- a/darts/tests/models/forecasting/test_deeptime.py +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -50,7 +50,6 @@ def test_fit(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=True, random_state=42, ) model.fit(large_ts[:98]) @@ -65,7 +64,6 @@ def test_fit(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=True, random_state=42, ) model2.fit(small_ts[:98]) @@ -90,7 +88,6 @@ def test_multivariate(self): inr_layers_width=64, n_fourier_feats=256, # must have enough fourier components to capture low freq scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=True, random_state=42, ) model.fit(series_multivariate) @@ -115,7 +112,6 @@ def test_multivariate(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=True, random_state=42, ) model.fit(series_multivariate, past_covariates=series_covariates) @@ -139,7 +135,6 @@ def test_deeptime_n_fourier_feats(self): inr_layers_width=20, n_fourier_feats=17, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=True, random_state=42, ) @@ -155,7 +150,6 @@ def test_logtensorboard(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=True, random_state=42, ) model.fit(ts) @@ -172,7 +166,6 @@ def test_activation_fns(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=True, activation="LeakyReLU", random_state=42, ) @@ -187,7 +180,6 @@ def test_activation_fns(self): inr_layers_width=20, n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], - legacy_optimiser=True, activation="invalid", random_state=42, ) From be7239ea41aa7f4cd5a1a58a60d184e2b225d57c Mon Sep 17 00:00:00 2001 From: madtoinou Date: Mon, 7 Nov 2022 15:35:08 +0100 Subject: [PATCH 11/18] corrected test file 
according to reviewer comments --- darts/tests/models/forecasting/test_deeptime.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py index 17f901eebd..d6aa44e0a0 100644 --- a/darts/tests/models/forecasting/test_deeptime.py +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -14,7 +14,7 @@ TORCH_AVAILABLE = True except ImportError: - logger.warning("Torch not available. TCN tests will be skipped.") + logger.warning("Torch not available. DeepTime tests will be skipped.") TORCH_AVAILABLE = False @@ -29,7 +29,7 @@ def tearDown(self): def test_creation(self): with self.assertRaises(ValueError): - # if a list is passed to the `layer_widths` argument, it must have a length equal to `num_stacks` + # if the `inr_layers_width` argument is a list, its length must be equal to `inr_num_layers` DeepTimeModel( input_chunk_length=1, output_chunk_length=1, @@ -95,7 +95,6 @@ def test_multivariate(self): # the theoretical result should be [[1.01, 1.02], [0.505, 0.51]]. # We just test if the given result is not too far on average. - print(res) self.assertTrue( abs(np.average(res - np.array([[1.01, 1.02], [0.505, 0.51]])) < 0.03) ) @@ -151,6 +150,8 @@ def test_logtensorboard(self): n_fourier_feats=64, scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], random_state=42, + log_tensorboard=True, + work_dir=self.temp_work_dir, ) model.fit(ts) model.predict(n=2) From 19104fc3c586591e39ba56dfa8890eee70534e09 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Wed, 16 Nov 2022 10:04:16 +0100 Subject: [PATCH 12/18] correction according to reviewer comments: DeepTime was added to test_global_forecasting_models, the number of epochs used during the tests reduced to the strict minimum, corrected typo in docstring, removed the mutable default argument, added check in TorchForecastingModel for n_epochs --- darts/models/forecasting/deeptime.py | 67 ++++++++----------- .../forecasting/torch_forecasting_model.py | 6 ++ .../tests/models/forecasting/test_deeptime.py | 17 +++-- .../test_global_forecasting_models.py | 11 +++ 4 files changed, 57 insertions(+), 44 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index f1077da99a..3bd190479a 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -302,8 +302,8 @@ def forward(self, x_in: Tensor) -> Tensor: def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor: """Return time axis encoded as float values between 0 and 1""" coords = torch.linspace(0, 1, lookback_len + horizon_len) - # would it be clearer with the unsqueeze operator? 
- return torch.reshape(coords, (1, lookback_len + horizon_len, 1)) + # coords.shape = [1, lookback_len + horizon_len, 1] + return coords.unsqueeze(dim=0).unsqueeze(dim=-1) def configure_optimizers(self): """Override the configure_optimizers to define three groups of parameters, one for the @@ -424,29 +424,17 @@ def __init__( inr_num_layers: int = 5, inr_layers_width: Union[int, List[int]] = 256, n_fourier_feats: int = 4096, - scales: List[float] = [0.01, 0.1, 1, 5, 10, 20, 50, 100], + scales: List[float] = None, dropout: float = 0.1, activation: str = "ReLU", - optimizer_kwargs: Optional[Dict] = { - "lr": 1e-3, - "lambda_lr": 1.0, - "weight_decay": 0.0, - }, - lr_scheduler_kwargs: Optional[Dict] = { - "warmup_epochs": 5, - "eta_min": 0.0, - "scheduler_names": [ - "cosine_annealing", - "cosine_annealing_with_linear_warmup", - "cosine_annealing_with_linear_warmup", - ], - }, + optimizer_kwargs: Optional[Dict] = None, + lr_scheduler_kwargs: Optional[Dict] = None, **kwargs, ): """Deep time-index model with meta-learning (DeepTIMe). This is an implementation of the DeepTime architecture, as outlined in [1]_. The default arguments - correspond to the hyper-parameters described in the published article. + correspond to the hyper-parameters described in the article. This model supports past covariates (known for `input_chunk_length` points before prediction time). @@ -566,6 +554,26 @@ def __init__( # extract pytorch lightning module kwargs self.pl_module_params = self._extract_pl_module_params(**self.model_params) + if scales is None: + scales = [0.01, 0.1, 1, 5, 10, 20, 50, 100] + + if self.pl_module_params["optimizer_kwargs"] is None: + self.pl_module_params["optimizer_kwargs"] = { + "lr": 1e-3, + "lambda_lr": 1.0, + "weight_decay": 0.0, + } + if self.pl_module_params["lr_scheduler_kwargs"] is None: + self.pl_module_params["lr_scheduler_kwargs"] = { + "warmup_epochs": 5, + "eta_min": 0.0, + "scheduler_names": [ + "cosine_annealing", + "cosine_annealing_with_linear_warmup", + "cosine_annealing_with_linear_warmup", + ], + } + raise_if_not( isinstance(inr_layers_width, int) or len(inr_layers_width) == inr_num_layers, @@ -580,26 +588,6 @@ def __init__( logger, ) - raise_if_not( - self.n_epochs > 0, - "`self.n_epochs` should be greater than 0, it is used to declare the learning rate schedulers.", - logger, - ) - - raise_if_not( - "optimizer_kwargs" in self.pl_module_params.keys(), - "`optimizer_kwargs` should contain the following arguments: `weight_decay`, " - "`lambda_lr` and `lr` in order to create the optimizer.", - logger, - ) - - raise_if_not( - "lr_scheduler_kwargs" in self.pl_module_params.keys(), - "`lr_scheduler_kwargs` should contain the following arguments: `eta_min`, " - "`warmup_epochs` and `scheduler_names` in order to create the optimizer.", - logger, - ) - expected_params = { "weight_decay", "lambda_lr", @@ -615,7 +603,7 @@ def __init__( raise_if_not( len(missing_params) == 0, f"Missing argument(s) for the optimiser: {missing_params}. 
`weight_decay`, " - "`lambda_lr` and `lr` must be defined in `optimizer_kwargs` whereas `Tmax`, " + "`lambda_lr` and `lr` must be defined in `optimizer_kwargs` whereas `eta_min`, " "`scheduler_names` and `warmup_epochs` must be defined in `lr_scheduler_kwargs`.", logger, ) @@ -645,7 +633,6 @@ def _create_model( output_dim = train_sample[-1].shape[1] nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters - # TODO: support nr_params functionnality model = _DeepTimeModule( input_dim=input_dim, output_dim=output_dim, diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index bc7d973b4d..d67dcd46e0 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -250,6 +250,12 @@ def __init__( super().__init__(add_encoders=add_encoders) suppress_lightning_warnings(suppress_all=not show_warnings) + raise_if_not( + n_epochs > 0, + "`n_epochs` should be greater than 0.", + logger, + ) + # We will fill these dynamically, upon first call of fit_from_dataset(): self.model: Optional[PLForecastingModule] = None self.train_sample: Optional[Tuple] = None diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py index d6aa44e0a0..cc4a675d19 100644 --- a/darts/tests/models/forecasting/test_deeptime.py +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -36,6 +36,15 @@ def test_creation(self): inr_num_layers=3, inr_layers_width=[1], ) + with self.assertRaises(ValueError): + # n_epochs should be greater than 0 for instantiation of the lr schedulers + DeepTimeModel( + input_chunk_length=1, + output_chunk_length=1, + inr_num_layers=3, + inr_layers_width=20, + n_epochs=0, + ) def test_fit(self): large_ts = tg.constant_timeseries(length=100, value=1000) @@ -129,7 +138,7 @@ def test_deeptime_n_fourier_feats(self): DeepTimeModel( input_chunk_length=1, output_chunk_length=1, - n_epochs=10, + n_epochs=1, inr_num_layers=2, inr_layers_width=20, n_fourier_feats=17, @@ -144,7 +153,7 @@ def test_logtensorboard(self): model = DeepTimeModel( input_chunk_length=1, output_chunk_length=1, - n_epochs=10, + n_epochs=1, inr_num_layers=2, inr_layers_width=20, n_fourier_feats=64, @@ -162,7 +171,7 @@ def test_activation_fns(self): model = DeepTimeModel( input_chunk_length=1, output_chunk_length=1, - n_epochs=10, + n_epochs=1, inr_num_layers=2, inr_layers_width=20, n_fourier_feats=64, @@ -176,7 +185,7 @@ def test_activation_fns(self): model = DeepTimeModel( input_chunk_length=1, output_chunk_length=1, - n_epochs=10, + n_epochs=1, inr_num_layers=2, inr_layers_width=20, n_fourier_feats=64, diff --git a/darts/tests/models/forecasting/test_global_forecasting_models.py b/darts/tests/models/forecasting/test_global_forecasting_models.py index 24923017f4..ade18ae3c3 100644 --- a/darts/tests/models/forecasting/test_global_forecasting_models.py +++ b/darts/tests/models/forecasting/test_global_forecasting_models.py @@ -22,6 +22,7 @@ from darts.models import ( BlockRNNModel, + DeepTimeModel, NBEATSModel, RNNModel, TCNModel, @@ -102,6 +103,16 @@ }, 100.0, ), + ( + DeepTimeModel, + { + "inr_num_layers": 2, + "inr_layers_width": 16, + "n_fourier_feats": 16, + "n_epochs": 10, + }, + 50.0, + ), ] class GlobalForecastingModelsTestCase(DartsBaseTestClass): From fead1eacc0842d9176139e1ee081dbe53962b8c0 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Wed, 16 Nov 2022 23:57:02 +0100 Subject: [PATCH 13/18] fix: added transpose in the target array for 
multivariate test and increase the length of the prediction from 2 to 3, last version was relying on erroneous broadcasting --- darts/tests/models/forecasting/test_deeptime.py | 11 ++++++++--- darts/tests/models/forecasting/test_nbeats_nhits.py | 7 +++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py index cc4a675d19..58f6137549 100644 --- a/darts/tests/models/forecasting/test_deeptime.py +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -100,12 +100,17 @@ def test_multivariate(self): random_state=42, ) model.fit(series_multivariate) - res = model.predict(n=2).values() + res = model.predict(n=3).values() - # the theoretical result should be [[1.01, 1.02], [0.505, 0.51]]. + # the theoretical result should be [[1.01, 1.02, 1.03], [0.505, 0.51, 0.515]]. # We just test if the given result is not too far on average. self.assertTrue( - abs(np.average(res - np.array([[1.01, 1.02], [0.505, 0.51]])) < 0.03) + abs( + np.average( + res - np.array([[1.01, 1.02, 1.03], [0.505, 0.51, 0.515]]).T + ) + < 0.03 + ) ) # Test Covariates diff --git a/darts/tests/models/forecasting/test_nbeats_nhits.py b/darts/tests/models/forecasting/test_nbeats_nhits.py index b14e23b4c7..0d2797dcdc 100644 --- a/darts/tests/models/forecasting/test_nbeats_nhits.py +++ b/darts/tests/models/forecasting/test_nbeats_nhits.py @@ -97,13 +97,16 @@ def test_multivariate(self): ) model.fit(series_multivariate) - res = model.predict(n=2).values() + res = model.predict(n=3).values() # the theoretical result should be [[1.01, 1.02], [0.505, 0.51]]. # We just test if the given result is not too far on average. self.assertTrue( abs( - np.average(res - np.array([[1.01, 1.02], [0.505, 0.51]])) < 0.03 + np.average( + res - np.array([[1.01, 1.02, 1.03], [0.505, 0.51, 0.515]]).T + ) + < 0.03 ) ) From 56dc64e83d17aadbcdff163e80e2e4b7103e9d2d Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 17 Nov 2022 00:00:56 +0100 Subject: [PATCH 14/18] feat: draft of the probabilistic predictions for deeptime, tests are failing at the moment --- darts/models/forecasting/deeptime.py | 64 ++++++++++++++----- .../forecasting/test_probabilistic_models.py | 15 ++++- 2 files changed, 62 insertions(+), 17 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index 3bd190479a..d22f211359 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -91,6 +91,7 @@ def __init__( scales: List[float], dropout: float, activation: str, + nr_params: int, ): """Implicit Neural Representation, mapping values to their coordinates using a Multi-Layer Perceptron. 
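# A minimal, self-contained sketch of the shape bookkeeping behind the `nr_params`
# argument added above: the INR emits one time representation per distribution
# parameter (e.g. two for a Gaussian likelihood), which is then sliced out per
# parameter. Tensor names and sizes below are illustrative only.
import torch

batch_size, in_len, out_len = 4, 24, 12
hidden_width, nr_params = 8, 2

# stand-in for the INR output: width hidden_width * nr_params per time step
time_reprs = torch.randn(1, in_len + out_len, hidden_width * nr_params)
time_reprs = time_reprs.repeat(batch_size, 1, 1)
time_reprs = time_reprs.reshape(batch_size, in_len + out_len, hidden_width, nr_params)

# each distribution parameter i gets its own lookback/horizon representations
lookback_reprs = time_reprs[:, :-out_len, :, 0]  # shape: (4, 24, 8)
horizon_reprs = time_reprs[:, -out_len:, :, 0]   # shape: (4, 12, 8)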
@@ -132,6 +133,7 @@ def __init__( self.n_fourier_feats = n_fourier_feats self.scales = scales self.dropout = dropout + self.nr_params = nr_params raise_if_not( activation in ACTIVATIONS, f"'{activation}' is not in {ACTIVATIONS}" @@ -156,7 +158,7 @@ def __init__( # TODO : solve ambiguity between the num of layers and the number of hidden layers last_width = feats_size linear_layer_stack_list = [] - for layer_width in self.hidden_layers_width: + for layer_width in self.hidden_layers_width[:-1]: linear_layer_stack_list.append(nn.Linear(last_width, layer_width)) linear_layer_stack_list.append(self.activation) @@ -167,6 +169,17 @@ def __init__( last_width = layer_width + # output width multiplied self.nr_params to have one time encoding per param + linear_layer_stack_list.append( + nn.Linear(last_width, self.hidden_layers_width[-1] * self.nr_params) + ) + linear_layer_stack_list.append(self.activation) + if self.dropout > 0: + linear_layer_stack_list.append(MonteCarloDropout(p=self.dropout)) + linear_layer_stack_list.append( + nn.LayerNorm(self.hidden_layers_width[-1] * self.nr_params) + ) + self.layers = nn.Sequential(*linear_layer_stack_list) def forward(self, x: Tensor) -> Tensor: @@ -273,30 +286,49 @@ def __init__( scales=self.scales, dropout=self.dropout, activation=self.activation, + nr_params=self.nr_params, ) self.adaptive_weights = RidgeRegressor() def forward(self, x_in: Tensor) -> Tensor: x, _ = x_in # x_in: (past_target|past_covariate, static_covariates) - batch_size, _, _ = x.shape # x: (batch_size, in_len, in_dim) - coords = self.get_coords(self.input_chunk_length, self.output_chunk_length) + # repeat values for each nr_params + # x = x.repeat(1, 1, self.nr_params) + coords = self.get_coords(self.input_chunk_length, self.output_chunk_length) time_reprs = self.inr(coords) - time_reprs = time_reprs.repeat(batch_size, 1, self.input_dim) - - lookback_reprs = time_reprs[:, : -self.output_chunk_length] - horizon_reprs = time_reprs[:, -self.output_chunk_length :] - # learn weights from the lookback - w, b = self.adaptive_weights(lookback_reprs, x) - # apply weights to the horizon - y = torch.einsum("... d o, ... t d -> ... 
t o", [w, horizon_reprs]) + b - - # taken from nbeats, could be simplified - y = y.view( - y.shape[0], self.output_chunk_length, self.input_dim, self.nr_params - )[:, :, : self.output_dim, :] + # time_reprs.shape = [b_size, input_chunk_len+output_chunk_len, self.hidden_layer_width[-1]*nr_params] + time_reprs = time_reprs.repeat(batch_size, 1, 1) + time_reprs = time_reprs.reshape( + batch_size, + self.input_chunk_length + self.output_chunk_length, + -1, + self.nr_params, + ) + + # must use a different time_reprs (A) for each nr_param so that the linear equation changes: + # AX = B where A is the diag of lookback_reprs.T*lookback_reprs and lookback_reprs.T*x + # the parameter lambda of the RidgeRegressor is shared across the nr_params + forecasts = [] + for i in range(self.nr_params): + lookback_reprs = time_reprs[:, : -self.output_chunk_length, :, i] + horizon_reprs = time_reprs[:, -self.output_chunk_length :, :, i] + + # learn weights from the lookback + w, b = self.adaptive_weights(lookback_reprs, x) + # apply weights to the horizon + forecast = torch.einsum("b d o, b t d -> b t o", [w, horizon_reprs]) + b + # forecast.shape = [batch, output_chunk_size, input_dim] + forecasts.append(forecast) + + # y.shape = [batch, output_chunk_size, input_dim, nr_params] + y = torch.stack(forecasts, dim=-1) + # retain forecast of target (exclude past/static covariates) + y = y[:, :, : self.output_dim * self.nr_params, :] + # TODO: check that target predictions are the first self.output_dim*self.nr_params values, change slicing? + # TODO: verify that the model benefits from covariates (potentially in the INR?!) return y def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor: diff --git a/darts/tests/models/forecasting/test_probabilistic_models.py b/darts/tests/models/forecasting/test_probabilistic_models.py index 6ecad577bd..6db3a05029 100644 --- a/darts/tests/models/forecasting/test_probabilistic_models.py +++ b/darts/tests/models/forecasting/test_probabilistic_models.py @@ -3,7 +3,7 @@ from darts import TimeSeries from darts.logging import get_logger from darts.metrics import mae -from darts.models import ARIMA, BATS, TBATS, ExponentialSmoothing +from darts.models import ARIMA, BATS, TBATS, DeepTimeModel, ExponentialSmoothing from darts.models.forecasting.forecasting_model import GlobalForecastingModel from darts.tests.base_test_class import DartsBaseTestClass from darts.utils import timeseries_generation as tg @@ -132,6 +132,19 @@ }, 1, ), + ( + DeepTimeModel, + { + "input_chunk_length": 10, + "output_chunk_length": 5, + "inr_num_layers": 2, + "inr_layers_width": 16, + "n_fourier_feats": 16, + "n_epochs": 10, + "likelihood": GaussianLikelihood(), + }, + 1, + ), ] From 4351ec5ef74fd16e1c58cfb309570905d6886599 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 17 Nov 2022 09:26:38 +0100 Subject: [PATCH 15/18] forgot to set the random_state in the test_probabilistic_models for DeepTime, removed some comments in the forward method, corrected typo --- darts/models/forecasting/deeptime.py | 8 +++----- .../models/forecasting/test_probabilistic_models.py | 11 ++++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index d22f211359..86c05581a2 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -294,12 +294,10 @@ def __init__( def forward(self, x_in: Tensor) -> Tensor: x, _ = x_in # x_in: (past_target|past_covariate, static_covariates) batch_size, _, _ = 
x.shape # x: (batch_size, in_len, in_dim) - # repeat values for each nr_params - # x = x.repeat(1, 1, self.nr_params) coords = self.get_coords(self.input_chunk_length, self.output_chunk_length) time_reprs = self.inr(coords) - # time_reprs.shape = [b_size, input_chunk_len+output_chunk_len, self.hidden_layer_width[-1]*nr_params] + # time_reprs.shape = [batch_size, input_chunk_len+output_chunk_len, inr_layers_width[-1]*nr_params] time_reprs = time_reprs.repeat(batch_size, 1, 1) time_reprs = time_reprs.reshape( batch_size, @@ -309,7 +307,7 @@ def forward(self, x_in: Tensor) -> Tensor: ) # must use a different time_reprs (A) for each nr_param so that the linear equation changes: - # AX = B where A is the diag of lookback_reprs.T*lookback_reprs and lookback_reprs.T*x + # AX = B where A is the diag of lookback_reprs.T*lookback_reprs and B is lookback_reprs.T*x # the parameter lambda of the RidgeRegressor is shared across the nr_params forecasts = [] for i in range(self.nr_params): @@ -328,7 +326,7 @@ def forward(self, x_in: Tensor) -> Tensor: # retain forecast of target (exclude past/static covariates) y = y[:, :, : self.output_dim * self.nr_params, :] # TODO: check that target predictions are the first self.output_dim*self.nr_params values, change slicing? - # TODO: verify that the model benefits from covariates (potentially in the INR?!) + # TODO: run experiments to check if the model benefits from covariates (potentially in the INR?!) return y def get_coords(self, lookback_len: int, horizon_len: int) -> Tensor: diff --git a/darts/tests/models/forecasting/test_probabilistic_models.py b/darts/tests/models/forecasting/test_probabilistic_models.py index 6db3a05029..ab681b0918 100644 --- a/darts/tests/models/forecasting/test_probabilistic_models.py +++ b/darts/tests/models/forecasting/test_probabilistic_models.py @@ -137,13 +137,14 @@ { "input_chunk_length": 10, "output_chunk_length": 5, - "inr_num_layers": 2, - "inr_layers_width": 16, - "n_fourier_feats": 16, - "n_epochs": 10, + "inr_num_layers": 5, + "inr_layers_width": 256, + "n_fourier_feats": 1024, + "n_epochs": 20, + "random_state": 0, "likelihood": GaussianLikelihood(), }, - 1, + 1.9, ), ] From f81ccbc6a252b3326ed27af9084541817d63edd7 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 25 Nov 2022 16:13:42 +0100 Subject: [PATCH 16/18] fix: removed outdated comment, added comment about optimizer and scheduler default argument, added check for scheduler warmup_epochs value, removed hardcoded typing of ridge regressor regularization term --- darts/models/forecasting/deeptime.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/darts/models/forecasting/deeptime.py b/darts/models/forecasting/deeptime.py index 86c05581a2..66f72ed856 100644 --- a/darts/models/forecasting/deeptime.py +++ b/darts/models/forecasting/deeptime.py @@ -191,7 +191,7 @@ class RidgeRegressor(nn.Module): def __init__(self, lambda_init: float = 0.0): """Implementation of the closed form Ridge Regression with a regularization coefficient.""" super().__init__() - self._lambda = nn.Parameter(torch.as_tensor(lambda_init, dtype=torch.float)) + self._lambda = nn.Parameter(torch.as_tensor(lambda_init)) def forward(self, reprs: Tensor, x: Tensor, reg_coeff: float = None) -> Tensor: if reg_coeff is None: @@ -268,7 +268,6 @@ def __init__( super().__init__(**kwargs) self.input_dim = input_dim self.output_dim = output_dim - # TODO: develop support for nr_params functionality self.nr_params = nr_params self.inr_num_layers = inr_num_layers 
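# A hedged, standalone sketch of the closed-form ridge solve that a RidgeRegressor-style
# module performs: softplus keeps the learnable regularisation coefficient positive, and
# the weights (plus a bias term) come from solving (X^T X + coeff * I) W = X^T Y.
# Everything below is an illustration, not the exact module implementation.
import torch
import torch.nn.functional as F

def ridge_solve(X: torch.Tensor, Y: torch.Tensor, raw_lambda: torch.Tensor):
    reg_coeff = F.softplus(raw_lambda)   # strictly positive regularisation coefficient
    ones = torch.ones(X.shape[0], 1)     # extra column acting as the bias feature
    Xb = torch.cat([X, ones], dim=-1)
    A = Xb.T @ Xb
    A.diagonal().add_(reg_coeff)         # A = Xb^T Xb + reg_coeff * I
    W = torch.linalg.solve(A, Xb.T @ Y)
    return W[:-1], W[-1:]                # linear weights, bias row

w, b = ridge_solve(torch.randn(24, 8), torch.randn(24, 3), torch.tensor(0.0))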
self.inr_layers_width = inr_layers_width @@ -436,15 +435,6 @@ def cosine_annealing_with_linear_warmup(current_epoch): "monitor": "val_loss", } - # TODO: must be implemented since inheriting from PLPastCovariatesModule - """ - def _produce_train_output(): - raise NotImplementedError() - - def _get_batch_prediction(): - raise NotImplementedError() - """ - class DeepTimeModel(PastCovariatesTorchModel): def __init__( @@ -618,6 +608,7 @@ def __init__( logger, ) + # user can either use default arguments or must redefine all of them expected_params = { "weight_decay", "lambda_lr", @@ -640,6 +631,15 @@ def __init__( self.pl_module_params["lr_scheduler_kwargs"]["total_epochs"] = self.n_epochs + raise_if_not( + self.n_epochs + > self.pl_module_params["lr_scheduler_kwargs"]["warmup_epochs"], + f"n_epochs ({self.n_epochs}) must be greater than the number of warmup epochs for the " + f"learning rate scheduler ({self.pl_module_params['lr_scheduler_kwargs']['warmup_epochs']}). " + f" This value is controlled by the `lr_scheduler_kwargs['warmup_epochs']` argument.", + logger, + ) + self.inr_num_layers = inr_num_layers self.inr_layers_width = inr_layers_width self.n_fourier_feats = n_fourier_feats From bff3b66a9c0922813e862531b78b391891cab4c5 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 25 Nov 2022 18:10:37 +0100 Subject: [PATCH 17/18] corrected tests after adding a check on the value of n_epochs compared to warmup_epochs --- .../tests/models/forecasting/test_deeptime.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/darts/tests/models/forecasting/test_deeptime.py b/darts/tests/models/forecasting/test_deeptime.py index 58f6137549..966ede3702 100644 --- a/darts/tests/models/forecasting/test_deeptime.py +++ b/darts/tests/models/forecasting/test_deeptime.py @@ -45,6 +45,15 @@ def test_creation(self): inr_layers_width=20, n_epochs=0, ) + with self.assertRaises(ValueError): + # n_epochs should be greater than warmup_epochs of lr schedulers + DeepTimeModel( + input_chunk_length=1, + output_chunk_length=1, + inr_num_layers=3, + inr_layers_width=20, + n_epochs=1, + ) def test_fit(self): large_ts = tg.constant_timeseries(length=100, value=1000) @@ -143,7 +152,7 @@ def test_deeptime_n_fourier_feats(self): DeepTimeModel( input_chunk_length=1, output_chunk_length=1, - n_epochs=1, + n_epochs=6, inr_num_layers=2, inr_layers_width=20, n_fourier_feats=17, @@ -158,7 +167,7 @@ def test_logtensorboard(self): model = DeepTimeModel( input_chunk_length=1, output_chunk_length=1, - n_epochs=1, + n_epochs=6, inr_num_layers=2, inr_layers_width=20, n_fourier_feats=64, @@ -176,11 +185,11 @@ def test_activation_fns(self): model = DeepTimeModel( input_chunk_length=1, output_chunk_length=1, - n_epochs=1, + n_epochs=6, inr_num_layers=2, - inr_layers_width=20, - n_fourier_feats=64, - scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + inr_layers_width=8, + n_fourier_feats=8, + scales=[0.01, 0.1], activation="LeakyReLU", random_state=42, ) @@ -190,11 +199,11 @@ def test_activation_fns(self): model = DeepTimeModel( input_chunk_length=1, output_chunk_length=1, - n_epochs=1, + n_epochs=6, inr_num_layers=2, - inr_layers_width=20, - n_fourier_feats=64, - scales=[0.01, 0.1, 1, 5, 10, 20, 50, 100], + inr_layers_width=8, + n_fourier_feats=8, + scales=[0.01, 0.1], activation="invalid", random_state=42, ) From 3f9e944678b5ddcff64ec57f4e293f842f81eb5d Mon Sep 17 00:00:00 2001 From: madtoinou Date: Mon, 28 Nov 2022 08:38:02 +0100 Subject: [PATCH 18/18] fix: corrected msising parenthesis due to auto-merging --- 
.../tests/models/forecasting/test_global_forecasting_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/darts/tests/models/forecasting/test_global_forecasting_models.py b/darts/tests/models/forecasting/test_global_forecasting_models.py index cf66c032a5..cca604a8e3 100644 --- a/darts/tests/models/forecasting/test_global_forecasting_models.py +++ b/darts/tests/models/forecasting/test_global_forecasting_models.py @@ -114,6 +114,8 @@ "n_epochs": 10, }, 50.0, + ), + ( NLinearModel, { "n_epochs": 10,