Skip to content

Commit

Permalink
add bug fix after merging
Browse files Browse the repository at this point in the history
  • Loading branch information
JanFidor committed Aug 7, 2023
1 parent 27dd610 commit 1eb41d8
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions darts/models/forecasting/transformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor):
# Allow teacher forcing
def training_step(self, train_batch, batch_idx) -> torch.Tensor:
"""performs the training step"""
train_batch = list(train_batch)
future_targets = train_batch[-1]
train_batch.append(future_targets)
return super().training_step(train_batch, batch_idx)
Expand Down

0 comments on commit 1eb41d8

Please sign in to comment.