diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py
index 50bcb10d25..b667dffe94 100644
--- a/darts/models/forecasting/block_rnn_model.py
+++ b/darts/models/forecasting/block_rnn_model.py
@@ -110,7 +110,8 @@ def forward(self, x_in: Tuple):
         """
         Here, we apply the FC network only on the last output point (at the last time step)
         """
-        predictions = hidden[:, -1, :]
+        # hidden has shape (num_layers, batch_size, hidden_dim): index the layer axis, not the batch axis
+        predictions = hidden[-1, :, :]
         predictions = self.fc(predictions)
         predictions = predictions.view(
             batch_size, self.out_len, self.target_size, self.nr_params
@@ -130,18 +130,19 @@ def _rnn_sequence(
     ):
         modules = []
+        is_lstm = self.name == "LSTM"
 
         for i in range(num_layers):
             input = input_size if (i == 0) else hidden_dim
             is_last = i == num_layers - 1
             rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True)
 
             modules.append(rnn)
-            modules.append(ExtractRnnOutput())
-
+            modules.append(ExtractRnnOutput(not is_last, is_lstm))
             if normalization:
                 modules.append(self._normalization_layer(normalization, hidden_dim))
-            if is_last:  # pytorch RNNs don't have dropout applied on the last layer
-                modules.append(nn.Dropout(dropout))
+            if not is_last:  # pytorch RNNs don't have dropout applied on the last layer
+                modules.append(nn.Dropout(dropout))
+
         return nn.Sequential(*modules)
 
     def _fc_layer(
diff --git a/darts/utils/torch.py b/darts/utils/torch.py
index d70141aba1..3e6934afc9 100644
--- a/darts/utils/torch.py
+++ b/darts/utils/torch.py
@@ -131,9 +131,18 @@ def _reshape_input(self, x):
 
 
 class ExtractRnnOutput(nn.Module):
-    def __init__(self) -> None:
+    def __init__(self, is_output: bool, is_lstm: bool) -> None:
         super().__init__()
+        self.is_output = is_output
+        self.is_lstm = is_lstm
 
     def forward(self, input):
-        output, _ = input
-        return output
+        # torch RNN layers return an (output, hidden) tuple
+        output, hidden = input
+        if self.is_output:
+            # intermediate layer: pass the full output sequence to the next layer
+            return output
+        if self.is_lstm:
+            # an LSTM hidden state is a (h_n, c_n) tuple; keep only h_n
+            return hidden[0]
+        return hidden
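
For context, here is a minimal runnable sketch of the mechanism the diff wires up: each single-layer RNN is followed by an `ExtractRnnOutput` so the `(output, hidden)` tuples compose inside `nn.Sequential`, and the last block emits `h_n` rather than the output sequence. The free-standing `rnn_sequence` helper below is illustrative only; normalization and the surrounding BlockRNN model plumbing are omitted.

```python
import torch
import torch.nn as nn


class ExtractRnnOutput(nn.Module):
    """Mirror of the module in the diff: unpacks the (output, hidden)
    tuple that torch RNN layers return, so layers chain in nn.Sequential."""

    def __init__(self, is_output: bool, is_lstm: bool) -> None:
        super().__init__()
        self.is_output = is_output
        self.is_lstm = is_lstm

    def forward(self, input):
        output, hidden = input
        if self.is_output:
            # intermediate layer: forward the full (batch, seq, hidden) sequence
            return output
        if self.is_lstm:
            # an LSTM hidden state is a (h_n, c_n) tuple; keep only h_n
            return hidden[0]
        return hidden


def rnn_sequence(name, input_size, hidden_dim, num_layers, dropout=0.0):
    # Standalone version of _rnn_sequence (illustrative name, no normalization).
    modules = []
    is_lstm = name == "LSTM"
    for i in range(num_layers):
        in_size = input_size if i == 0 else hidden_dim
        is_last = i == num_layers - 1
        modules.append(getattr(nn, name)(in_size, hidden_dim, 1, batch_first=True))
        modules.append(ExtractRnnOutput(not is_last, is_lstm))
        if not is_last:  # match nn.RNN semantics: no dropout after the last layer
            modules.append(nn.Dropout(dropout))
    return nn.Sequential(*modules)


stack = rnn_sequence("LSTM", input_size=4, hidden_dim=8, num_layers=3, dropout=0.1)
x = torch.randn(2, 12, 4)      # (batch, seq, input_size)
hidden = stack(x)              # h_n of the last block: (1, batch, hidden_dim)
assert hidden.shape == (1, 2, 8)
print(hidden[-1, :, :].shape)  # (batch, hidden_dim), as used in forward()
```

Intermediate blocks forward the full `(batch, seq, hidden)` sequence so the next RNN sees every time step, while the final block's `h_n` has shape `(1, batch, hidden_dim)`; that is why `forward` indexes the layer axis with `hidden[-1, :, :]` rather than the batch axis.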