fix_rnn_models
elephaint committed Sep 27, 2024
1 parent b3fafc3 commit 87af3ac
Showing 11 changed files with 121 additions and 82 deletions.
8 changes: 5 additions & 3 deletions nbs/models.deepar.ipynb
@@ -280,6 +280,8 @@
" input_encoder = 1 + self.futr_exog_size + self.stat_exog_size\n",
"\n",
" # Instantiate model\n",
" self.rnn_state = None\n",
" self.maintain_state = False\n",
" self.hist_encoder = nn.LSTM(input_size=input_encoder,\n",
" hidden_size=self.encoder_hidden_size,\n",
" num_layers=self.encoder_n_layers,\n",
@@ -304,7 +306,7 @@
" encoder_input = torch.cat((encoder_input, futr_exog), dim=2)\n",
"\n",
" if self.stat_exog_size > 0:\n",
" stat_exog = stat_exog.unsqueeze(1).repeat(1, input_size, 1) # [B, S] -> [B, input_size-1, S]\n",
" stat_exog = stat_exog.unsqueeze(1).repeat(1, input_size, 1) # [B, S] -> [B, input_size-1, S]\n",
" encoder_input = torch.cat((encoder_input, stat_exog), dim=2)\n",
"\n",
" # RNN forward\n",
@@ -314,13 +316,13 @@
" rnn_state = None\n",
"\n",
" hidden_state, rnn_state = self.hist_encoder(encoder_input, \n",
" rnn_state) # [B, input_size-1, rnn_hidden_state]\n",
" rnn_state) # [B, input_size-1, rnn_hidden_state]\n",
"\n",
" if self.maintain_state:\n",
" self.rnn_state = rnn_state\n",
"\n",
" # Decoder forward\n",
" output = self.decoder(hidden_state) # [B, input_size-1, output_size]\n",
" output = self.decoder(hidden_state) # [B, input_size-1, output_size]\n",
"\n",
" # Return only horizon part\n",
" return output[:, -self.h:]"
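Note on the change above: the DeepAR hunks add two encoder attributes, rnn_state and maintain_state, and thread them through the LSTM call so consecutive forward passes can reuse the hidden state. The following is a minimal sketch of that pattern, not the actual neuralforecast module; the class name, sizes, and the usage lines are illustrative assumptions.

import torch
import torch.nn as nn

class StatefulEncoderSketch(nn.Module):
    """Illustrative module mirroring the maintain_state pattern in the diff."""

    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 2):
        super().__init__()
        # Same bookkeeping the commit adds before building the LSTM encoder.
        self.rnn_state = None          # last (h_n, c_n) tuple, kept only when maintaining state
        self.maintain_state = False    # toggled externally for step-by-step prediction
        self.hist_encoder = nn.LSTM(input_size=input_size,
                                    hidden_size=hidden_size,
                                    num_layers=num_layers,
                                    batch_first=True)

    def forward(self, encoder_input: torch.Tensor) -> torch.Tensor:
        # Reuse the stored state across calls only when explicitly requested.
        rnn_state = self.rnn_state if self.maintain_state else None
        hidden_state, rnn_state = self.hist_encoder(encoder_input, rnn_state)
        if self.maintain_state:
            self.rnn_state = rnn_state
        return hidden_state

# Usage: consecutive windows share hidden state only while maintain_state is True.
enc = StatefulEncoderSketch(input_size=3, hidden_size=8)
enc.maintain_state = True
out1 = enc(torch.randn(4, 12, 3))   # [B=4, L=12, F=3] -> [4, 12, 8]
out2 = enc(torch.randn(4, 12, 3))   # continues from the state left by the first call
enc.maintain_state = False          # back to stateless behaviour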
38 changes: 18 additions & 20 deletions nbs/models.dilated_rnn.ipynb
@@ -408,7 +408,7 @@
"\n",
" def __init__(self,\n",
" h: int,\n",
" input_size: int = -1,\n",
" input_size: int,\n",
" inference_input_size: int = -1,\n",
" cell_type: str = 'LSTM',\n",
" dilations: List[List[int]] = [[1, 2], [4, 8]],\n",
@@ -419,6 +419,7 @@
" futr_exog_list = None,\n",
" hist_exog_list = None,\n",
" stat_exog_list = None,\n",
" exclude_insample_y = False,\n",
" loss = MAE(),\n",
" valid_loss = None,\n",
" max_steps: int = 1000,\n",
@@ -444,7 +445,10 @@
" super(DilatedRNN, self).__init__(\n",
" h=h,\n",
" input_size=input_size,\n",
" inference_input_size=inference_input_size,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
" exclude_insample_y = exclude_insample_y,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" max_steps=max_steps,\n",
@@ -459,12 +463,9 @@
" start_padding_enabled=start_padding_enabled,\n",
" step_size=step_size,\n",
" scaler_type=scaler_type,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
" random_seed=random_seed,\n",
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
@@ -502,11 +503,11 @@
" self.rnn_stack = nn.Sequential(*layers)\n",
"\n",
" # Context adapter\n",
" self.context_adapter = nn.Linear(in_features=self.input_size,\n",
" out_features=h)\n",
" self.context_adapter = nn.Linear(in_features=self.encoder_hidden_size,\n",
" out_features=self.context_size * h)\n",
"\n",
" # Decoder MLP\n",
" self.mlp_decoder = MLP(in_features=self.encoder_hidden_size + self.futr_exog_size,\n",
" self.mlp_decoder = MLP(in_features=self.context_size * h + self.futr_exog_size,\n",
" out_features=self.loss.outputsize_multiplier,\n",
" hidden_size=self.decoder_hidden_size,\n",
" num_layers=self.decoder_layers,\n",
@@ -522,17 +523,17 @@
" stat_exog = windows_batch['stat_exog'] # [B, S]\n",
"\n",
" # Concatenate y, historic and static inputs \n",
" batch_size, input_size = encoder_input.shape[:2]\n",
" batch_size, seq_len = encoder_input.shape[:2]\n",
" if self.hist_exog_size > 0:\n",
" encoder_input = torch.cat((encoder_input, hist_exog), dim=2) # [B, L, 1] + [B, L, X] -> [B, L, 1 + X]\n",
"\n",
" if self.stat_exog_size > 0:\n",
" stat_exog = stat_exog.unsqueeze(1).repeat(1, input_size, 1) # [B, S] -> [B, L, S]\n",
" stat_exog = stat_exog.unsqueeze(1).repeat(1, seq_len, 1) # [B, S] -> [B, L, S]\n",
" encoder_input = torch.cat((encoder_input, stat_exog), dim=2) # [B, L, 1 + X] + [B, L, S] -> [B, L, 1 + X + S]\n",
"\n",
" if self.futr_exog_size > 0:\n",
" encoder_input = torch.cat((encoder_input, \n",
" futr_exog[:, :input_size]), dim=2) # [B, L, 1 + X + S] + [B, L, F] -> [B, L, 1 + X + S + F]\n",
" futr_exog[:, :seq_len]), dim=2) # [B, L, 1 + X + S] + [B, L, F] -> [B, L, 1 + X + S + F]\n",
"\n",
" # DilatedRNN forward\n",
" for layer_num in range(len(self.rnn_stack)):\n",
@@ -543,20 +544,17 @@
" encoder_input = output\n",
"\n",
" # Context adapter\n",
" output = output.permute(0, 2, 1) # [B, L, C] -> [B, C, L]\n",
" context = self.context_adapter(output) # [B, C, L] -> [B, C, h]\n",
" context = self.context_adapter(output) # [B, L, C] -> [B, L, context_size * h]\n",
"\n",
" # Residual connection with futr_exog\n",
" if self.futr_exog_size > 0:\n",
" futr_exog_futr = futr_exog[:, input_size:].swapaxes(1, 2) # [B, L + h, F] -> [B, F, h] \n",
" context = torch.cat((context, futr_exog_futr), dim=1) # [B, C, h] + [B, F, h] = [B, C + F, h]\n",
"\n",
" context = context.swapaxes(1, 2) # [B, C + F, h] -> [B, h, C + F]\n",
" context = torch.cat((context, futr_exog[:, :seq_len]), \n",
" dim=-1) # [B, L, context_size * h] + [B, L, F] = [B, L, context_size * h + F]\n",
"\n",
" # Final forecast\n",
" output = self.mlp_decoder(context) # [B, h, C + F] -> [B, h, n_output]\n",
" output = self.mlp_decoder(context) # [B, L, context_size * h + F] -> [B, L, n_output]\n",
" \n",
" return output"
" return output[:, -self.h:]"
]
},
{
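Note on the DilatedRNN change: the context adapter now maps the encoder's hidden dimension to context_size * h (instead of permuting and projecting the time axis), the residual connection concatenates the in-sample slice of futr_exog along the feature dimension, and only the last h steps of the decoded sequence are returned. The shape walk-through below traces those steps with made-up sizes; the Sequential stands in for neuralforecast's MLP decoder and is not its real signature.

import torch
import torch.nn as nn

# Shape walk-through of the reworked decoder path (all sizes are illustrative).
B, L, C = 4, 24, 16           # batch, encoded sequence length, encoder hidden size
h, context_size, F = 6, 2, 3  # horizon, context multiplier, future-exogenous features
n_output = 1

context_adapter = nn.Linear(in_features=C, out_features=context_size * h)
mlp_decoder = nn.Sequential(  # stand-in for the MLP decoder used in the notebook
    nn.Linear(context_size * h + F, 32), nn.ReLU(), nn.Linear(32, n_output))

output = torch.randn(B, L, C)                              # encoder output
context = context_adapter(output)                          # [B, L, C] -> [B, L, context_size * h]
futr_exog = torch.randn(B, L + h, F)                       # future exogenous spans insample + horizon
context = torch.cat((context, futr_exog[:, :L]), dim=-1)   # [B, L, context_size * h + F]
decoded = mlp_decoder(context)                             # [B, L, n_output]
forecast = decoded[:, -h:]                                 # keep only the horizon part, as in the diff
print(forecast.shape)                                      # torch.Size([4, 6, 1])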
24 changes: 16 additions & 8 deletions nbs/models.gru.ipynb
@@ -156,6 +156,8 @@
" futr_exog_list = None,\n",
" hist_exog_list = None,\n",
" stat_exog_list = None,\n",
" exclude_insample_y = False,\n",
" recurrent = False,\n",
" loss = MAE(),\n",
" valid_loss = None,\n",
" max_steps: int = 1000,\n",
@@ -178,10 +180,16 @@
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" **trainer_kwargs):\n",
" \n",
" self.RECURRENT = recurrent\n",
"\n",
" super(GRU, self).__init__(\n",
" h=h,\n",
" input_size=input_size,\n",
" # inference_input_size=inference_input_size,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
" exclude_insample_y = exclude_insample_y,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" max_steps=max_steps,\n",
@@ -196,12 +204,9 @@
" start_padding_enabled=start_padding_enabled,\n",
" step_size=step_size,\n",
" scaler_type=scaler_type,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
" random_seed=random_seed,\n",
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
@@ -227,6 +232,7 @@
"\n",
" # Instantiate model\n",
" self.rnn_state = None\n",
" self.maintain_state = False\n",
" self.hist_encoder = nn.GRU(input_size=input_encoder,\n",
" hidden_size=self.encoder_hidden_size,\n",
" num_layers=self.encoder_n_layers,\n",
@@ -263,7 +269,8 @@
" encoder_input = torch.cat((encoder_input, stat_exog), dim=2) # [B, seq_len, 1 + X] + [B, seq_len, S] -> [B, seq_len, 1 + X + S]\n",
"\n",
" if self.futr_exog_size > 0:\n",
" encoder_input = torch.cat((encoder_input, futr_exog), dim=2) # [B, seq_len, 1 + X + S] + [B, seq_len, F] -> [B, seq_len, 1 + X + S + F]\n",
" encoder_input = torch.cat((encoder_input, \n",
" futr_exog[:, :seq_len]), dim=2) # [B, seq_len, 1 + X + S] + [B, seq_len, F] -> [B, seq_len, 1 + X + S + F]\n",
"\n",
" # RNN forward\n",
" if self.maintain_state:\n",
Expand All @@ -272,7 +279,7 @@
" rnn_state = None\n",
" \n",
" hidden_state, rnn_state = self.hist_encoder(encoder_input, \n",
" rnn_state) # [B, seq_len, rnn_hidden_state]\n",
" rnn_state) # [B, seq_len, rnn_hidden_state]\n",
" if self.maintain_state:\n",
" self.rnn_state = rnn_state\n",
"\n",
@@ -281,7 +288,8 @@
"\n",
" # Residual connection with futr_exog\n",
" if self.futr_exog_size > 0:\n",
" context = torch.cat((context, futr_exog), dim=-1) # [B, seq_len, context_size * h] + [B, seq_len, F] = [B, seq_len, context_size * h + F]\n",
" context = torch.cat((context, futr_exog[:, :seq_len]), \n",
" dim=-1) # [B, seq_len, context_size * h] + [B, seq_len, F] = [B, seq_len, context_size * h + F]\n",
"\n",
" # Final forecast\n",
" output = self.mlp_decoder(context) # [B, seq_len, context_size * h + F] -> [B, seq_len, n_output]\n",
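Note on the GRU change: both concatenations now trim futr_exog to the encoder's sequence length with futr_exog[:, :seq_len], because the future-exogenous tensor spans the in-sample window plus the horizon while the encoder input covers only the in-sample window. A small sketch of the mismatch this avoids, with illustrative sizes:

import torch

B, L, h, F = 2, 10, 3, 4
insample_y = torch.randn(B, L, 1)       # encoder sees only the in-sample window
futr_exog  = torch.randn(B, L + h, F)   # future exogenous covers in-sample + horizon

seq_len = insample_y.shape[1]
# Without the slice the lengths disagree (L vs L + h) and torch.cat raises a RuntimeError.
encoder_input = torch.cat((insample_y, futr_exog[:, :seq_len]), dim=2)  # [B, L, 1 + F]
print(encoder_input.shape)  # torch.Size([2, 10, 5])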
2 changes: 1 addition & 1 deletion nbs/models.itransformer.ipynb
@@ -433,7 +433,7 @@
"logging.getLogger(\"lightning_fabric\").setLevel(logging.ERROR)\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" check_model(iTransformer)"
" check_model(iTransformer, [\"airpassengers\"])"
]
},
{
13 changes: 10 additions & 3 deletions nbs/models.lstm.ipynb
@@ -152,6 +152,7 @@
" hist_exog_list = None,\n",
" stat_exog_list = None,\n",
" exclude_insample_y = False,\n",
" recurrent = False,\n",
" loss = MAE(),\n",
" valid_loss = None,\n",
" max_steps: int = 1000,\n",
@@ -174,6 +175,9 @@
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" **trainer_kwargs):\n",
" \n",
" self.RECURRENT = recurrent\n",
" \n",
" super(LSTM, self).__init__(\n",
" h=h,\n",
" input_size=input_size,\n",
@@ -223,6 +227,7 @@
"\n",
" # Instantiate model\n",
" self.rnn_state = None\n",
" self.maintain_state = False\n",
" self.hist_encoder = nn.LSTM(input_size=input_encoder,\n",
" hidden_size=self.encoder_hidden_size,\n",
" num_layers=self.encoder_n_layers,\n",
@@ -261,7 +266,8 @@
" encoder_input = torch.cat((encoder_input, stat_exog), dim=2) # [B, seq_len, 1 + X] + [B, seq_len, S] -> [B, seq_len, 1 + X + S]\n",
"\n",
" if self.futr_exog_size > 0:\n",
" encoder_input = torch.cat((encoder_input, futr_exog), dim=2) # [B, seq_len, 1 + X + S] + [B, seq_len, F] -> [B, seq_len, 1 + X + S + F]\n",
" encoder_input = torch.cat((encoder_input, \n",
" futr_exog[:, :seq_len]), dim=2) # [B, seq_len, 1 + X + S] + [B, seq_len, F] -> [B, seq_len, 1 + X + S + F]\n",
"\n",
" # RNN forward\n",
" if self.maintain_state:\n",
@@ -279,7 +285,8 @@
"\n",
" # Residual connection with futr_exog\n",
" if self.futr_exog_size > 0:\n",
" context = torch.cat((context, futr_exog), dim=-1) # [B, seq_len, context_size * h] + [B, seq_len, F] = [B, seq_len, context_size * h + F]\n",
" context = torch.cat((context, futr_exog[:, :seq_len]), \n",
" dim=-1) # [B, seq_len, context_size * h] + [B, seq_len, F] = [B, seq_len, context_size * h + F]\n",
"\n",
" # Final forecast\n",
" output = self.mlp_decoder(context) # [B, seq_len, context_size * h + F] -> [B, seq_len, n_output]\n",
@@ -326,7 +333,7 @@
"logging.getLogger(\"lightning_fabric\").setLevel(logging.ERROR)\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" check_model(LSTM)"
" check_model(LSTM, [\"airpassengers\"])"
]
},
{
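Note on the LSTM change: besides the state-keeping attributes, the constructor gains a recurrent argument and assigns self.RECURRENT = recurrent before delegating to the parent constructor. One plausible reason for that ordering is that the base class reads the flag while initialising; the sketch below illustrates the idea with stand-in classes and is an assumption, not the actual neuralforecast base model.

# Illustrative only: why the flag is assigned before super().__init__().
# BaseModelSketch stands in for neuralforecast's base class; the real base may
# differ in how (or whether) it reads the attribute.
class BaseModelSketch:
    def __init__(self, h: int, input_size: int):
        self.h = h
        self.input_size = input_size
        # The base constructor can branch on the subclass flag only if the
        # subclass sets it *before* delegating here.
        recurrent = getattr(self, "RECURRENT", False)
        self.decode_strategy = "recurrent" if recurrent else "direct"

class LSTMSketch(BaseModelSketch):
    def __init__(self, h: int, input_size: int, recurrent: bool = False):
        self.RECURRENT = recurrent          # set first, mirroring the diff
        super().__init__(h=h, input_size=input_size)

print(LSTMSketch(h=12, input_size=24, recurrent=True).decode_strategy)  # "recurrent"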
23 changes: 15 additions & 8 deletions nbs/models.rnn.ipynb
@@ -159,6 +159,7 @@
" hist_exog_list = None,\n",
" stat_exog_list = None,\n",
" exclude_insample_y = False,\n",
" recurrent = False,\n",
" loss = MAE(),\n",
" valid_loss = None,\n",
" max_steps: int = 1000,\n",
@@ -181,10 +182,16 @@
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None, \n",
" **trainer_kwargs):\n",
" \n",
" self.RECURRENT = recurrent\n",
"\n",
" super(RNN, self).__init__(\n",
" h=h,\n",
" input_size=input_size,\n",
" inference_input_size=inference_input_size,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
" exclude_insample_y = exclude_insample_y,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" max_steps=max_steps,\n",
@@ -199,13 +206,9 @@
" start_padding_enabled=start_padding_enabled,\n",
" step_size=step_size,\n",
" scaler_type=scaler_type,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
" exclude_insample_y = exclude_insample_y,\n",
" random_seed=random_seed,\n",
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
@@ -231,6 +234,8 @@
" input_encoder = 1 + self.hist_exog_size + self.stat_exog_size + self.futr_exog_size\n",
"\n",
" # Instantiate model\n",
" self.rnn_state = None\n",
" self.maintain_state = False\n",
" self.hist_encoder = nn.RNN(input_size=input_encoder,\n",
" hidden_size=self.encoder_hidden_size,\n",
" num_layers=self.encoder_n_layers,\n",
@@ -270,7 +275,8 @@
" encoder_input = torch.cat((encoder_input, stat_exog), dim=2) # [B, seq_len, 1 + X] + [B, seq_len, S] -> [B, seq_len, 1 + X + S]\n",
"\n",
" if self.futr_exog_size > 0:\n",
" encoder_input = torch.cat((encoder_input, futr_exog), dim=2) # [B, seq_len, 1 + X + S] + [B, seq_len, F] -> [B, seq_len, 1 + X + S + F]\n",
" encoder_input = torch.cat((encoder_input, \n",
" futr_exog[:, :seq_len]), dim=2) # [B, seq_len, 1 + X + S] + [B, seq_len, F] -> [B, seq_len, 1 + X + S + F]\n",
"\n",
" # RNN forward \n",
" if self.maintain_state:\n",
Expand All @@ -289,7 +295,8 @@
"\n",
" # Residual connection with futr_exog\n",
" if self.futr_exog_size > 0:\n",
" context = torch.cat((context, futr_exog), dim=-1) # [B, seq_len, context_size * h] + [B, seq_len, F] = [B, seq_len, context_size * h + F]\n",
" context = torch.cat((context, \n",
" futr_exog[:, :seq_len]), dim=-1) # [B, seq_len, context_size * h] + [B, seq_len, F] = [B, seq_len, context_size * h + F]\n",
"\n",
" # Final forecast\n",
" output = self.mlp_decoder(context) # [B, seq_len, context_size * h + F] -> [B, seq_len, n_output]\n",
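Note on the RNN change: the static-exogenous branch broadcasts stat_exog from one vector per series to one vector per time step before concatenation. The one-liner it relies on is shown in isolation below with illustrative sizes; the expand variant is an optional aside, not part of the commit.

import torch

B, L, S = 2, 8, 3
stat_exog = torch.randn(B, S)                       # one static vector per series
expanded = stat_exog.unsqueeze(1).repeat(1, L, 1)   # [B, S] -> [B, 1, S] -> [B, L, S]
print(expanded.shape)                               # torch.Size([2, 8, 3])

# torch.expand reaches the same result without copying memory:
expanded_view = stat_exog.unsqueeze(1).expand(B, L, S)
assert torch.equal(expanded, expanded_view)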
2 changes: 2 additions & 0 deletions neuralforecast/models/deepar.py
@@ -190,6 +190,8 @@ def __init__(
input_encoder = 1 + self.futr_exog_size + self.stat_exog_size

# Instantiate model
self.rnn_state = None
self.maintain_state = False
self.hist_encoder = nn.LSTM(
input_size=input_encoder,
hidden_size=self.encoder_hidden_size,
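Finally, a hedged end-to-end usage sketch for the updated recurrent models. It assumes a neuralforecast build that includes this commit, so that the new recurrent flag is accepted exactly as shown in the notebook hunks; the hyperparameter values and the bundled AirPassengersDF sample dataset are illustrative.

from neuralforecast import NeuralForecast
from neuralforecast.models import LSTM
from neuralforecast.utils import AirPassengersDF

# recurrent=True is the flag added in this commit (see the LSTM/GRU/RNN hunks);
# the remaining arguments follow the constructor signatures shown above.
model = LSTM(h=12, input_size=24, encoder_hidden_size=64, max_steps=100, recurrent=True)
nf = NeuralForecast(models=[model], freq="M")
nf.fit(df=AirPassengersDF)
forecasts = nf.predict()
print(forecasts.head())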
