Skip to content

Commit

Permalink
pass hidden state to fc layer
Browse files Browse the repository at this point in the history
  • Loading branch information
JanFidor committed Jul 17, 2023
1 parent c5ddf47 commit ccb204d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
10 changes: 6 additions & 4 deletions darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def forward(self, x_in: Tuple):

""" Here, we apply the FC network only on the last output point (at the last time step)
"""
predictions = hidden[:, -1, :]
predictions = hidden[-1, :, :]
predictions = self.fc(predictions)
predictions = predictions.view(
batch_size, self.out_len, self.target_size, self.nr_params
Expand All @@ -130,18 +130,20 @@ def _rnn_sequence(
):

modules = []
is_lstm = self.name == "LSTM"
for i in range(num_layers):
input = input_size if (i == 0) else hidden_dim
is_last = i == num_layers - 1
rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True)

modules.append(rnn)
modules.append(ExtractRnnOutput())

modules.append(ExtractRnnOutput(not is_last, is_lstm))
modules.append(nn.Dropout(dropout))
if normalization:
modules.append(self._normalization_layer(normalization, hidden_dim))
if is_last: # pytorch RNNs don't have dropout applied on the last layer
if not is_last: # pytorch RNNs don't have dropout applied on the last layer
modules.append(nn.Dropout(dropout))

return nn.Sequential(*modules)

def _fc_layer(
Expand Down
12 changes: 9 additions & 3 deletions darts/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,15 @@ def _reshape_input(self, x):


class ExtractRnnOutput(nn.Module):
def __init__(self) -> None:
def __init__(self, is_output, is_lstm) -> None:
self.is_output = is_output
self.is_lstm = is_lstm
super().__init__()

def forward(self, input):
output, _ = input
return output
output, hidden = input
if self.is_output:
return output
if self.is_lstm:
return hidden[0]
return hidden

0 comments on commit ccb204d

Please sign in to comment.