Skip to content

Commit

Permalink
clean up unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
JanFidor committed Jul 17, 2023
1 parent 92b4cde commit c5ddf47
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,15 @@ def _rnn_sequence(
modules = []
for i in range(num_layers):
input = input_size if (i == 0) else hidden_dim
is_last = i == num_layers - 1
rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True)

modules.append(rnn)
modules.append(ExtractRnnOutput())

if normalization:
modules.append(self._normalization_layer(normalization, hidden_dim))
if (
i < num_layers - 1
): # pytorch RNNs don't have dropout applied on the last layer
if is_last: # pytorch RNNs don't have dropout applied on the last layer
modules.append(nn.Dropout(dropout))
return nn.Sequential(*modules)

Expand All @@ -161,20 +160,15 @@ def _fc_layer(
self.output_chunk_length * target_size * self.nr_params
]:
if normalization:
feats.append(self._normalization_layer(normalization, last, False))
feats.append(self._normalization_layer(normalization, last))
feats.append(nn.Linear(last, feature))
last = feature
return nn.Sequential(*feats)

def _normalization_layer(
self, normalization: str, hidden_size: int, is_temporal: bool
):
def _normalization_layer(self, normalization: str, hidden_size: int):

if normalization == "batch":
if is_temporal:
return TemporalBatchNorm1d(hidden_size)
else:
return nn.BatchNorm1d(hidden_size)
return TemporalBatchNorm1d(hidden_size)
elif normalization == "layer":
return nn.LayerNorm(hidden_size)

Expand Down

0 comments on commit c5ddf47

Please sign in to comment.