diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py
index f98fbb327e..c49e809ded 100644
--- a/darts/models/forecasting/transformer_model.py
+++ b/darts/models/forecasting/transformer_model.py
@@ -8,6 +8,8 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Transformer
 
 from darts.logging import get_logger, raise_if, raise_if_not, raise_log
 from darts.models.components import glu_variants, layer_norm_variants
@@ -188,6 +190,7 @@ def __init__(
         self.target_size = output_size
         self.nr_params = nr_params
         self.target_length = self.output_chunk_length
+        self.d_model = d_model
 
         self.encoder = nn.Linear(input_size, d_model)
         self.positional_encoding = _PositionalEncoding(
@@ -281,48 +284,105 @@ def __init__(
             custom_decoder=custom_decoder,
         )
 
-        self.decoder = nn.Linear(
-            d_model, self.target_length * self.target_size * self.nr_params
-        )
-
-    def _create_transformer_inputs(self, data):
-        # '_TimeSeriesSequentialDataset' stores time series in the
-        # (batch_size, input_chunk_length, input_size) format. PyTorch's nn.Transformer
-        # module needs it the (input_chunk_length, batch_size, input_size) format.
-        # Therefore, the first two dimensions need to be swapped.
-        src = data.permute(1, 0, 2)
-        tgt = src[-1:, :, :]
-
-        return src, tgt
+        self.decoder = nn.Linear(d_model, self.target_size * self.nr_params)
 
     @io_processor
     def forward(self, x_in: Tuple):
-        data, _ = x_in
-        # Here we create 'src' and 'tgt', the inputs for the encoder and decoder
-        # side of the Transformer architecture
-        src, tgt = self._create_transformer_inputs(data)
+        """
+        During training (teacher forcing), x_in = tuple(past_target + past_covariates, static_covariates, future_targets)
+        During inference, x_in = tuple(past_target + past_covariates, static_covariates)
+        '_TimeSeriesSequentialDataset' stores time series in the
+        (batch_size, input_chunk_length, input_size) format. PyTorch's nn.Transformer
+        module needs it in the (input_chunk_length, batch_size, input_size) format.
+        Therefore, the first two dimensions need to be swapped.
+        """
+        src = x_in[0].permute(1, 0, 2)
+        pad_size = (0, self.input_size - self.target_size)
+
+        # the start token consists only of the target series; past covariates are replaced with 0 padding
+        start_token = src[-1:, :, : self.target_size]
+        start_token_padded = F.pad(start_token, pad_size)
+
+        if len(x_in) == 3:
+            tgt = x_in[-1].permute(1, 0, 2)
+            tgt = F.pad(tgt, pad_size)
+            tgt = torch.cat([start_token_padded, tgt], dim=0)
+            return self._prediction_step(src, tgt)[:, :-1, :, :]
+
+        tgt = start_token_padded
+
+        predictions = []
+        for _ in range(self.target_length):
+            pred = self._prediction_step(src, tgt)[:, -1, :, :]
+            predictions.append(pred)
+            tgt = torch.cat(
+                [tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)],
+                dim=0,
+            )  # take average of samples
+        return torch.stack(predictions, dim=1)
+
+    def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor):
+        target_length = tgt.shape[0]
+        device, tensor_type = src.device, src.dtype
 
         # "math.sqrt(self.input_size)" is a normalization factor
         # see section 3.2.1 in 'Attention is All you Need' by Vaswani et al. (2017)
-        src = self.encoder(src) * math.sqrt(self.input_size)
-        src = self.positional_encoding(src)
+        src = self.encoder(src) * math.sqrt(self.d_model)
+        tgt = self.encoder(tgt) * math.sqrt(self.d_model)
 
-        tgt = self.encoder(tgt) * math.sqrt(self.input_size)
+        src = self.positional_encoding(src)
         tgt = self.positional_encoding(tgt)
 
-        x = self.transformer(src=src, tgt=tgt)
+        tgt_mask = Transformer.generate_square_subsequent_mask(
+            target_length, device
+        ).to(dtype=tensor_type)
+
+        x = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
         out = self.decoder(x)
 
         # Here we change the data format
         # from (1, batch_size, output_chunk_length * output_size)
         # to (batch_size, output_chunk_length, output_size, nr_params)
-        predictions = out[0, :, :]
+        predictions = out.permute(1, 0, 2)
         predictions = predictions.view(
-            -1, self.target_length, self.target_size, self.nr_params
+            -1, target_length, self.target_size, self.nr_params
         )
 
         return predictions
 
+    def training_step(self, train_batch, batch_idx) -> torch.Tensor:
+        """Performs the training step."""
+        train_batch = list(train_batch)
+        future_targets = train_batch[-1]
+        train_batch.append(future_targets)
+        return super().training_step(train_batch, batch_idx)
+
+    def _produce_train_output(self, input_batch: Tuple):
+        """
+        Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for
+        training.
+
+        Parameters:
+        input_batch
+            ``(past_target, past_covariates, static_covariates, future_target)`` during training
+
+            ``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced)
+        """
+
+        past_target, past_covariates, static_covariates = input_batch[:3]
+        # Currently all our PastCovariates models require past target and covariates concatenated
+        inpt = [
+            torch.cat([past_target, past_covariates], dim=2)
+            if past_covariates is not None
+            else past_target,
+            static_covariates,
+        ]
+
+        # add future targets when training (teacher forcing)
+        if len(input_batch) == 4:
+            inpt.append(input_batch[-1])
+        return self(inpt)
+
 
 class TransformerModel(PastCovariatesTorchModel):
     def __init__(
@@ -351,7 +411,7 @@ def __init__(
         The multi-head attention mechanism is highly parallelizable, which makes the transformer architecture
         very suitable to be trained with GPUs.
 
-        The transformer architecture implemented here is based on [1]_.
+        The transformer architecture implemented here is based on [1]_ and uses teacher forcing [4]_.
 
         This model supports past covariates (known for `input_chunk_length` points before prediction time).
 
@@ -388,7 +448,7 @@ def __init__(
             Fraction of neurons affected by Dropout (default=0.1).
         activation
             The activation function of encoder/decoder intermediate layer, (default='relu').
-            can be one of the glu variant's FeedForward Network (FFN)[2]. A feedforward network is a
+            can be one of the glu variant's FeedForward Network (FFN)[3]. A feedforward network is a
             fully-connected layer with an activation. The glu variant's FeedForward Network are a
             series of FFNs designed to work better with Transformer based models. ["GLU", "Bilinear", "ReGLU",
             "GEGLU", "SwiGLU", "ReLU", "GELU"] or one the pytorch internal activations ["relu", "gelu"]
@@ -541,17 +601,7 @@ def encode_year(idx):
         .. [2] Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arVix https://arxiv.org/abs/2002.05202.
         .. [3] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
                Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
-
-        Notes
-        -----
-        Disclaimer:
-        This current implementation is fully functional and can already produce some good predictions. However,
-        it is still limited in how it uses the Transformer architecture because the `tgt` input of
-        `torch.nn.Transformer` is not utilized to its full extent. Currently, we simply pass the last value of the
-        `src` input to `tgt`. To get closer to the way the Transformer is usually used in language models, we
-        should allow the model to consume its own output as part of the `tgt` argument, such that when predicting
-        sequences of values, the input to the `tgt` argument would grow as outputs of the transformer model would be
-        added to it. Of course, the training of the model would have to be adapted accordingly.
+        .. [4] Teacher Forcing PyTorch tutorial: https://github.com/pytorch/examples/tree/main/word_language_model
 
         Examples
         --------
diff --git a/darts/tests/models/forecasting/test_probabilistic_models.py b/darts/tests/models/forecasting/test_probabilistic_models.py
index eaae6c0c9f..1bdc8bb56e 100644
--- a/darts/tests/models/forecasting/test_probabilistic_models.py
+++ b/darts/tests/models/forecasting/test_probabilistic_models.py
@@ -142,6 +142,7 @@
             "output_chunk_length": 5,
             "n_epochs": 20,
             "random_state": 0,
+            "norm_type": "LayerNorm",
             "likelihood": GaussianLikelihood(),
             **tfm_kwargs,
         },