From 084541f8c8bd056d44234831581e828e1ef8ea1a Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sat, 22 Jul 2023 22:48:38 +0200 Subject: [PATCH 01/10] add teacher forcing to transformer model --- darts/models/forecasting/transformer_model.py | 99 +++++++++++++++---- 1 file changed, 80 insertions(+), 19 deletions(-) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index 8e578cff88..d022d907b7 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from darts.logging import get_logger, raise_if, raise_if_not, raise_log from darts.models.components import glu_variants, layer_norm_variants @@ -184,7 +185,8 @@ def __init__( self.nr_params = nr_params self.target_length = self.output_chunk_length - self.encoder = nn.Linear(input_size, d_model) + self.encoder_src = nn.Linear(input_size, d_model) + self.encoder_tgt = nn.Linear(self.target_size, d_model) self.positional_encoding = _PositionalEncoding( d_model, dropout, self.input_chunk_length ) @@ -276,47 +278,106 @@ def __init__( custom_decoder=custom_decoder, ) - self.decoder = nn.Linear( - d_model, self.output_chunk_length * self.target_size * self.nr_params - ) + self.decoder = nn.Linear(d_model, self.target_size * self.nr_params) - def _create_transformer_inputs(self, data): + def _permute_transformer_inputs(self, data): # '_TimeSeriesSequentialDataset' stores time series in the # (batch_size, input_chunk_length, input_size) format. PyTorch's nn.Transformer # module needs it the (input_chunk_length, batch_size, input_size) format. # Therefore, the first two dimensions need to be swapped. - src = data.permute(1, 0, 2) - tgt = src[-1:, :, :] - - return src, tgt + return data.permute(1, 0, 2) def forward(self, x_in: Tuple): + """ + When teacher forcing, x_in = (past_target + past_covariates, static_covariates, future_targets) + When inference, x_in = (past_target + past_covariates, static_covariates) + + """ + data = x_in[0] + + start_token = self._permute_transformer_inputs(data[:, -1:, :]) + if len(x_in) == 3: + src, _, tgt = x_in + src = self._permute_transformer_inputs(src) + tgt_permuted = self._permute_transformer_inputs(tgt) + tgt_padded = F.pad(tgt_permuted, (0, self.input_size - self.target_size)) + tgt = torch.cat([start_token, tgt_padded], dim=0) + return self._prediction_step(src, tgt)[:, :-1, :, :] + data, _ = x_in - # Here we create 'src' and 'tgt', the inputs for the encoder and decoder - # side of the Transformer architecture - src, tgt = self._create_transformer_inputs(data) + src = self._permute_transformer_inputs(data) + tgt = start_token + + predictions = [] + for _ in range(self.output_chunk_length): + pred = self._prediction_step(src, tgt)[:, -1, :, :] + predictions.append(pred) + tgt = torch.cat( + [tgt, pred.mean(dim=2).unsqueeze(dim=0)], + dim=0, + ) # take average of quantiles + return torch.stack(predictions, dim=1) + + def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor): + batch_size = src.shape[1] # "math.sqrt(self.input_size)" is a normalization factor # see section 3.2.1 in 'Attention is All you Need' by Vaswani et al. (2017) - src = self.encoder(src) * math.sqrt(self.input_size) + src = self.encoder_src(src) * math.sqrt(self.input_size) src = self.positional_encoding(src) - tgt = self.encoder(tgt) * math.sqrt(self.input_size) + tgt = self.encoder_src(tgt) * math.sqrt(self.input_size) tgt = self.positional_encoding(tgt) - x = self.transformer(src=src, tgt=tgt) + tgt_mask = self.transformer.generate_square_subsequent_mask( + tgt.shape[0], src.get_device() + ) + x = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask) out = self.decoder(x) # Here we change the data format # from (1, batch_size, output_chunk_length * output_size) # to (batch_size, output_chunk_length, output_size, nr_params) - predictions = out[0, :, :] - predictions = predictions.view( - -1, self.target_length, self.target_size, self.nr_params - ) + predictions = out.permute(1, 0, 2) + predictions = predictions.view(batch_size, -1, self.target_size, self.nr_params) return predictions + # Allow teacher forcing + def training_step(self, train_batch, batch_idx) -> torch.Tensor: + """performs the training step""" + future_targets = train_batch[-1] + train_batch.append(future_targets) + return super().training_step(train_batch, batch_idx) + + def _produce_train_output(self, input_batch: Tuple): + """ + Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for + training. + + Parameters: if len(inp + # print([x.shape if x is not None else x for x in train_batch], "TRAIN")ut_batch) != 4: + a = 1 + ---------- + a = 1 + input_batch + ``(past_target, past_covariates, static_covariates)`` + """ + + past_target, past_covariates, static_covariates = input_batch[:3] + # Currently all our PastCovariates models require past target and covariates concatenated + inpt = [ + torch.cat([past_target, past_covariates], dim=2) + if past_covariates is not None + else past_target, + static_covariates, + ] + + # add future targets when teacher forcing + if len(input_batch) == 4: + inpt.append(input_batch[-1]) + return self(inpt) + class TransformerModel(PastCovariatesTorchModel): def __init__( From 492c3e535f7914d389111143363e10b757faa09d Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sun, 23 Jul 2023 13:52:38 +0200 Subject: [PATCH 02/10] change encoder to the original implementation (the best metrics) and use d_model as normalization factor --- darts/models/forecasting/transformer_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index d022d907b7..6a61866154 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -184,9 +184,9 @@ def __init__( self.target_size = output_size self.nr_params = nr_params self.target_length = self.output_chunk_length + self.d_model = d_model - self.encoder_src = nn.Linear(input_size, d_model) - self.encoder_tgt = nn.Linear(self.target_size, d_model) + self.encoder = nn.Linear(input_size, d_model) self.positional_encoding = _PositionalEncoding( d_model, dropout, self.input_chunk_length ) @@ -323,10 +323,10 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor): batch_size = src.shape[1] # "math.sqrt(self.input_size)" is a normalization factor # see section 3.2.1 in 'Attention is All you Need' by Vaswani et al. (2017) - src = self.encoder_src(src) * math.sqrt(self.input_size) - src = self.positional_encoding(src) + src = self.encoder(src) * math.sqrt(self.d_model) + tgt = self.encoder(tgt) * math.sqrt(self.d_model) - tgt = self.encoder_src(tgt) * math.sqrt(self.input_size) + src = self.positional_encoding(src) tgt = self.positional_encoding(tgt) tgt_mask = self.transformer.generate_square_subsequent_mask( From 30ccd785fcce5ab5aba9d5d735d5c6c26005c780 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sun, 23 Jul 2023 15:12:36 +0200 Subject: [PATCH 03/10] add 0 padding to targets, to compensate for missing past covariates --- darts/models/forecasting/transformer_model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index 6a61866154..d8061e24ae 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -295,26 +295,28 @@ def forward(self, x_in: Tuple): """ data = x_in[0] - start_token = self._permute_transformer_inputs(data[:, -1:, :]) + pad_size = (0, self.input_size - self.target_size) + start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size]) + start_token_padded = F.pad(start_token, pad_size) if len(x_in) == 3: src, _, tgt = x_in src = self._permute_transformer_inputs(src) tgt_permuted = self._permute_transformer_inputs(tgt) - tgt_padded = F.pad(tgt_permuted, (0, self.input_size - self.target_size)) - tgt = torch.cat([start_token, tgt_padded], dim=0) + tgt_padded = F.pad(tgt_permuted, pad_size) + tgt = torch.cat([start_token_padded, tgt_padded], dim=0) return self._prediction_step(src, tgt)[:, :-1, :, :] data, _ = x_in src = self._permute_transformer_inputs(data) - tgt = start_token + tgt = start_token_padded predictions = [] for _ in range(self.output_chunk_length): pred = self._prediction_step(src, tgt)[:, -1, :, :] predictions.append(pred) tgt = torch.cat( - [tgt, pred.mean(dim=2).unsqueeze(dim=0)], + [tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)], dim=0, ) # take average of quantiles return torch.stack(predictions, dim=1) From fd138309e72e366351e56d9d4be7c5c38a47c421 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sun, 23 Jul 2023 15:19:58 +0200 Subject: [PATCH 04/10] add comments --- darts/models/forecasting/transformer_model.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index d8061e24ae..d17082cb08 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -289,15 +289,16 @@ def _permute_transformer_inputs(self, data): def forward(self, x_in: Tuple): """ - When teacher forcing, x_in = (past_target + past_covariates, static_covariates, future_targets) - When inference, x_in = (past_target + past_covariates, static_covariates) - + During training (teacher forcing) x_in = tuple(past_target + past_covariates, static_covariates, future_targets) + During inference x_in = tuple(past_target + past_covariates, static_covariates) """ data = x_in[0] - pad_size = (0, self.input_size - self.target_size) + + # start token consists only of target series, past covariates are substituted with 0 padding start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size]) start_token_padded = F.pad(start_token, pad_size) + if len(x_in) == 3: src, _, tgt = x_in src = self._permute_transformer_inputs(src) @@ -340,7 +341,7 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor): # Here we change the data format # from (1, batch_size, output_chunk_length * output_size) # to (batch_size, output_chunk_length, output_size, nr_params) - predictions = out.permute(1, 0, 2) + predictions = self._permute_transformer_inputs(out) predictions = predictions.view(batch_size, -1, self.target_size, self.nr_params) return predictions @@ -357,13 +358,11 @@ def _produce_train_output(self, input_batch: Tuple): Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for training. - Parameters: if len(inp - # print([x.shape if x is not None else x for x in train_batch], "TRAIN")ut_batch) != 4: - a = 1 - ---------- - a = 1 + Parameters: input_batch - ``(past_target, past_covariates, static_covariates)`` + ``(past_target, past_covariates, static_covariates, future_target)`` during training + + ``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced) """ past_target, past_covariates, static_covariates = input_batch[:3] @@ -375,7 +374,7 @@ def _produce_train_output(self, input_batch: Tuple): static_covariates, ] - # add future targets when teacher forcing + # add future targets when training (teacher forcing) if len(input_batch) == 4: inpt.append(input_batch[-1]) return self(inpt) From 915fbb5c276f863a08ebd03d5ce445ed13639eca Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sun, 23 Jul 2023 17:03:06 +0200 Subject: [PATCH 05/10] add bug fix after merging --- darts/models/forecasting/transformer_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index d17082cb08..6d494cfdb8 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -349,6 +349,7 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor): # Allow teacher forcing def training_step(self, train_batch, batch_idx) -> torch.Tensor: """performs the training step""" + train_batch = list(train_batch) future_targets = train_batch[-1] train_batch.append(future_targets) return super().training_step(train_batch, batch_idx) From 7971bb8cacf50d90ea5294590128aa44e6fc7187 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sat, 22 Jul 2023 22:48:38 +0200 Subject: [PATCH 06/10] add teacher forcing to transformer model --- darts/models/forecasting/transformer_model.py | 99 +++++++++++++++---- 1 file changed, 80 insertions(+), 19 deletions(-) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index 8e578cff88..d022d907b7 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from darts.logging import get_logger, raise_if, raise_if_not, raise_log from darts.models.components import glu_variants, layer_norm_variants @@ -184,7 +185,8 @@ def __init__( self.nr_params = nr_params self.target_length = self.output_chunk_length - self.encoder = nn.Linear(input_size, d_model) + self.encoder_src = nn.Linear(input_size, d_model) + self.encoder_tgt = nn.Linear(self.target_size, d_model) self.positional_encoding = _PositionalEncoding( d_model, dropout, self.input_chunk_length ) @@ -276,47 +278,106 @@ def __init__( custom_decoder=custom_decoder, ) - self.decoder = nn.Linear( - d_model, self.output_chunk_length * self.target_size * self.nr_params - ) + self.decoder = nn.Linear(d_model, self.target_size * self.nr_params) - def _create_transformer_inputs(self, data): + def _permute_transformer_inputs(self, data): # '_TimeSeriesSequentialDataset' stores time series in the # (batch_size, input_chunk_length, input_size) format. PyTorch's nn.Transformer # module needs it the (input_chunk_length, batch_size, input_size) format. # Therefore, the first two dimensions need to be swapped. - src = data.permute(1, 0, 2) - tgt = src[-1:, :, :] - - return src, tgt + return data.permute(1, 0, 2) def forward(self, x_in: Tuple): + """ + When teacher forcing, x_in = (past_target + past_covariates, static_covariates, future_targets) + When inference, x_in = (past_target + past_covariates, static_covariates) + + """ + data = x_in[0] + + start_token = self._permute_transformer_inputs(data[:, -1:, :]) + if len(x_in) == 3: + src, _, tgt = x_in + src = self._permute_transformer_inputs(src) + tgt_permuted = self._permute_transformer_inputs(tgt) + tgt_padded = F.pad(tgt_permuted, (0, self.input_size - self.target_size)) + tgt = torch.cat([start_token, tgt_padded], dim=0) + return self._prediction_step(src, tgt)[:, :-1, :, :] + data, _ = x_in - # Here we create 'src' and 'tgt', the inputs for the encoder and decoder - # side of the Transformer architecture - src, tgt = self._create_transformer_inputs(data) + src = self._permute_transformer_inputs(data) + tgt = start_token + + predictions = [] + for _ in range(self.output_chunk_length): + pred = self._prediction_step(src, tgt)[:, -1, :, :] + predictions.append(pred) + tgt = torch.cat( + [tgt, pred.mean(dim=2).unsqueeze(dim=0)], + dim=0, + ) # take average of quantiles + return torch.stack(predictions, dim=1) + + def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor): + batch_size = src.shape[1] # "math.sqrt(self.input_size)" is a normalization factor # see section 3.2.1 in 'Attention is All you Need' by Vaswani et al. (2017) - src = self.encoder(src) * math.sqrt(self.input_size) + src = self.encoder_src(src) * math.sqrt(self.input_size) src = self.positional_encoding(src) - tgt = self.encoder(tgt) * math.sqrt(self.input_size) + tgt = self.encoder_src(tgt) * math.sqrt(self.input_size) tgt = self.positional_encoding(tgt) - x = self.transformer(src=src, tgt=tgt) + tgt_mask = self.transformer.generate_square_subsequent_mask( + tgt.shape[0], src.get_device() + ) + x = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask) out = self.decoder(x) # Here we change the data format # from (1, batch_size, output_chunk_length * output_size) # to (batch_size, output_chunk_length, output_size, nr_params) - predictions = out[0, :, :] - predictions = predictions.view( - -1, self.target_length, self.target_size, self.nr_params - ) + predictions = out.permute(1, 0, 2) + predictions = predictions.view(batch_size, -1, self.target_size, self.nr_params) return predictions + # Allow teacher forcing + def training_step(self, train_batch, batch_idx) -> torch.Tensor: + """performs the training step""" + future_targets = train_batch[-1] + train_batch.append(future_targets) + return super().training_step(train_batch, batch_idx) + + def _produce_train_output(self, input_batch: Tuple): + """ + Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for + training. + + Parameters: if len(inp + # print([x.shape if x is not None else x for x in train_batch], "TRAIN")ut_batch) != 4: + a = 1 + ---------- + a = 1 + input_batch + ``(past_target, past_covariates, static_covariates)`` + """ + + past_target, past_covariates, static_covariates = input_batch[:3] + # Currently all our PastCovariates models require past target and covariates concatenated + inpt = [ + torch.cat([past_target, past_covariates], dim=2) + if past_covariates is not None + else past_target, + static_covariates, + ] + + # add future targets when teacher forcing + if len(input_batch) == 4: + inpt.append(input_batch[-1]) + return self(inpt) + class TransformerModel(PastCovariatesTorchModel): def __init__( From c3cd5821ddef514bd73a50c8e1e65d6b72c53379 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sun, 23 Jul 2023 13:52:38 +0200 Subject: [PATCH 07/10] change encoder to the original implementation (the best metrics) and use d_model as normalization factor --- darts/models/forecasting/transformer_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index d022d907b7..6a61866154 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -184,9 +184,9 @@ def __init__( self.target_size = output_size self.nr_params = nr_params self.target_length = self.output_chunk_length + self.d_model = d_model - self.encoder_src = nn.Linear(input_size, d_model) - self.encoder_tgt = nn.Linear(self.target_size, d_model) + self.encoder = nn.Linear(input_size, d_model) self.positional_encoding = _PositionalEncoding( d_model, dropout, self.input_chunk_length ) @@ -323,10 +323,10 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor): batch_size = src.shape[1] # "math.sqrt(self.input_size)" is a normalization factor # see section 3.2.1 in 'Attention is All you Need' by Vaswani et al. (2017) - src = self.encoder_src(src) * math.sqrt(self.input_size) - src = self.positional_encoding(src) + src = self.encoder(src) * math.sqrt(self.d_model) + tgt = self.encoder(tgt) * math.sqrt(self.d_model) - tgt = self.encoder_src(tgt) * math.sqrt(self.input_size) + src = self.positional_encoding(src) tgt = self.positional_encoding(tgt) tgt_mask = self.transformer.generate_square_subsequent_mask( From 11979b479bea0c101f32f7fca28f0115356e6686 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sun, 23 Jul 2023 15:12:36 +0200 Subject: [PATCH 08/10] add 0 padding to targets, to compensate for missing past covariates --- darts/models/forecasting/transformer_model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index 6a61866154..d8061e24ae 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -295,26 +295,28 @@ def forward(self, x_in: Tuple): """ data = x_in[0] - start_token = self._permute_transformer_inputs(data[:, -1:, :]) + pad_size = (0, self.input_size - self.target_size) + start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size]) + start_token_padded = F.pad(start_token, pad_size) if len(x_in) == 3: src, _, tgt = x_in src = self._permute_transformer_inputs(src) tgt_permuted = self._permute_transformer_inputs(tgt) - tgt_padded = F.pad(tgt_permuted, (0, self.input_size - self.target_size)) - tgt = torch.cat([start_token, tgt_padded], dim=0) + tgt_padded = F.pad(tgt_permuted, pad_size) + tgt = torch.cat([start_token_padded, tgt_padded], dim=0) return self._prediction_step(src, tgt)[:, :-1, :, :] data, _ = x_in src = self._permute_transformer_inputs(data) - tgt = start_token + tgt = start_token_padded predictions = [] for _ in range(self.output_chunk_length): pred = self._prediction_step(src, tgt)[:, -1, :, :] predictions.append(pred) tgt = torch.cat( - [tgt, pred.mean(dim=2).unsqueeze(dim=0)], + [tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)], dim=0, ) # take average of quantiles return torch.stack(predictions, dim=1) From 5cfd8a444c09e18cffb0aaf131ff50b244651dd8 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sun, 23 Jul 2023 15:19:58 +0200 Subject: [PATCH 09/10] add comments --- darts/models/forecasting/transformer_model.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index d8061e24ae..d17082cb08 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -289,15 +289,16 @@ def _permute_transformer_inputs(self, data): def forward(self, x_in: Tuple): """ - When teacher forcing, x_in = (past_target + past_covariates, static_covariates, future_targets) - When inference, x_in = (past_target + past_covariates, static_covariates) - + During training (teacher forcing) x_in = tuple(past_target + past_covariates, static_covariates, future_targets) + During inference x_in = tuple(past_target + past_covariates, static_covariates) """ data = x_in[0] - pad_size = (0, self.input_size - self.target_size) + + # start token consists only of target series, past covariates are substituted with 0 padding start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size]) start_token_padded = F.pad(start_token, pad_size) + if len(x_in) == 3: src, _, tgt = x_in src = self._permute_transformer_inputs(src) @@ -340,7 +341,7 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor): # Here we change the data format # from (1, batch_size, output_chunk_length * output_size) # to (batch_size, output_chunk_length, output_size, nr_params) - predictions = out.permute(1, 0, 2) + predictions = self._permute_transformer_inputs(out) predictions = predictions.view(batch_size, -1, self.target_size, self.nr_params) return predictions @@ -357,13 +358,11 @@ def _produce_train_output(self, input_batch: Tuple): Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for training. - Parameters: if len(inp - # print([x.shape if x is not None else x for x in train_batch], "TRAIN")ut_batch) != 4: - a = 1 - ---------- - a = 1 + Parameters: input_batch - ``(past_target, past_covariates, static_covariates)`` + ``(past_target, past_covariates, static_covariates, future_target)`` during training + + ``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced) """ past_target, past_covariates, static_covariates = input_batch[:3] @@ -375,7 +374,7 @@ def _produce_train_output(self, input_batch: Tuple): static_covariates, ] - # add future targets when teacher forcing + # add future targets when training (teacher forcing) if len(input_batch) == 4: inpt.append(input_batch[-1]) return self(inpt) From 2d609eb79a78801cf28ba0da9d6ec4d225c0bfb9 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Sun, 23 Jul 2023 17:03:06 +0200 Subject: [PATCH 10/10] add bug fix after merging --- darts/models/forecasting/transformer_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index d17082cb08..6d494cfdb8 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -349,6 +349,7 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor): # Allow teacher forcing def training_step(self, train_batch, batch_idx) -> torch.Tensor: """performs the training step""" + train_batch = list(train_batch) future_targets = train_batch[-1] train_batch.append(future_targets) return super().training_step(train_batch, batch_idx)